Skip to main content

shape_runtime/type_system/
checker.rs

//! Type Checker
//!
//! Performs type checking on Shape programs using the type inference engine
//! and reports type errors with helpful messages.
5
6use super::errors::{TypeError, TypeErrorWithLocation, TypeResult};
7use super::inference::TypeInferenceEngine;
8use super::*;
9use shape_ast::ast::{EnumDef, Expr, Item, Program, Span, Statement, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
/// Selects how `TypeChecker::check_program` reacts to inference failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeAnalysisMode {
    /// Stop and report as soon as program inference returns an error.
    FailFast,
    /// Run best-effort inference and report every error it collected.
    RecoverAll,
}
17
18pub struct TypeChecker {
19    /// Type inference engine
20    inference_engine: TypeInferenceEngine,
21    /// Collected errors
22    errors: Vec<TypeErrorWithLocation>,
23    /// Source code for error reporting
24    source: Option<String>,
25    /// File name for error reporting
26    filename: Option<String>,
27    /// Enum definitions for resolving named types
28    enum_defs: HashMap<String, EnumDef>,
29    /// Current function's parameter types (name -> type annotation)
30    current_function_params: HashMap<String, shape_ast::ast::TypeAnnotation>,
31    /// Error emission behavior for semantic analysis.
32    analysis_mode: TypeAnalysisMode,
33}
34
35impl Default for TypeChecker {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TypeChecker {
42    pub fn new() -> Self {
43        TypeChecker {
44            inference_engine: TypeInferenceEngine::new(),
45            errors: Vec::new(),
46            source: None,
47            filename: None,
48            enum_defs: HashMap::new(),
49            current_function_params: HashMap::new(),
50            analysis_mode: TypeAnalysisMode::FailFast,
51        }
52    }
53
54    /// Set source code for error reporting
55    pub fn with_source(mut self, source: String) -> Self {
56        self.source = Some(source);
57        self
58    }
59
60    /// Set filename for error reporting
61    pub fn with_filename(mut self, filename: String) -> Self {
62        self.filename = Some(filename);
63        self
64    }
65
66    /// Register host-provided root-scope bindings (e.g. extension module namespaces).
67    pub fn with_known_bindings(mut self, names: &[String]) -> Self {
68        self.inference_engine.register_known_bindings(names);
69        self
70    }
71
72    pub fn with_analysis_mode(mut self, mode: TypeAnalysisMode) -> Self {
73        self.analysis_mode = mode;
74        self
75    }
76
77    /// Type check a complete program
78    pub fn check_program(
79        &mut self,
80        program: &Program,
81    ) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
82        // Clear previous errors
83        self.errors.clear();
84        self.enum_defs.clear();
85
86        // Collect enum definitions for type resolution
87        for item in &program.items {
88            if let Item::Enum(enum_def, _) = item {
89                self.enum_defs
90                    .insert(enum_def.name.clone(), enum_def.clone());
91            }
92        }
93
94        let types = match self.analysis_mode {
95            TypeAnalysisMode::FailFast => match self.inference_engine.infer_program(program) {
96                Ok(types) => types,
97                Err(err) => {
98                    self.add_inference_error(err);
99                    return Err(self.errors.clone());
100                }
101            },
102            TypeAnalysisMode::RecoverAll => {
103                let (types, inference_errors) =
104                    self.inference_engine.infer_program_best_effort(program);
105                for err in inference_errors {
106                    self.add_inference_error(err);
107                }
108                types
109            }
110        };
111
112        // Perform additional type checking
113        self.check_items(&program.items);
114
115        // Check exhaustiveness of match expressions
116        self.check_expressions(&program.items);
117
118        self.prune_error_cascades();
119
120        if self.errors.is_empty() {
121            // Convert inference types to semantic types
122            let semantic_types: HashMap<String, SemanticType> = types
123                .iter()
124                .filter_map(|(name, ty)| ty.to_semantic().map(|st| (name.clone(), st)))
125                .collect();
126
127            Ok(TypeCheckResult {
128                types,
129                semantic_types,
130                warnings: Vec::new(),
131            })
132        } else {
133            Err(self.errors.clone())
134        }
135    }
136
137    fn add_inference_error(&mut self, err: TypeError) {
138        let (line, col) = self.find_inference_error_position(&err);
139        self.add_error(err, line, col);
140    }
141
142    fn prune_error_cascades(&mut self) {
143        let has_specific_errors = self
144            .errors
145            .iter()
146            .any(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
147        if has_specific_errors {
148            self.errors
149                .retain(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
150        }
151
152        let mut seen = HashSet::new();
153        self.errors.retain(|err| {
154            let key = (err.line, err.column, err.error.to_string());
155            seen.insert(key)
156        });
157    }
158
159    fn find_inference_error_position(&self, error: &TypeError) -> (usize, usize) {
160        match error {
161            TypeError::UnknownProperty(_, property) => {
162                if let Some(span) = self
163                    .inference_engine
164                    .lookup_unknown_property_origin(property)
165                {
166                    if let Some((line, col)) = self.span_to_line_col(span) {
167                        return (line, col);
168                    }
169                }
170                (0, 0)
171            }
172            TypeError::UndefinedVariable(name) => self
173                .inference_engine
174                .lookup_undefined_variable_origin(name)
175                .and_then(|span| self.span_to_line_col(span))
176                .unwrap_or((0, 0)),
177            TypeError::UnsolvedConstraints(constraints) => {
178                if let Some(span) = self
179                    .inference_engine
180                    .find_origin_for_unsolved_constraints(constraints)
181                {
182                    if let Some((line, col)) = self.span_to_line_col(span) {
183                        return (line, col);
184                    }
185                }
186                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
187                    if let Some((line, col)) = self.span_to_line_col(span) {
188                        return (line, col);
189                    }
190                }
191                (0, 0)
192            }
193            TypeError::InvalidAssertion(_, _) => (0, 0),
194            TypeError::NonExhaustiveMatch { enum_name, .. } => self
195                .inference_engine
196                .lookup_non_exhaustive_match_origin(enum_name)
197                .and_then(|span| self.span_to_line_col(span))
198                .unwrap_or((0, 0)),
199            TypeError::GenericTypeError { symbol, .. } => {
200                if let Some(symbol) = symbol
201                    && let Some(span) = self
202                        .inference_engine
203                        .lookup_callable_origin_for_name(symbol)
204                    && let Some((line, col)) = self.span_to_line_col(span)
205                {
206                    return (line, col);
207                }
208                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
209                    if let Some((line, col)) = self.span_to_line_col(span) {
210                        return (line, col);
211                    }
212                }
213                (0, 0)
214            }
215            _ => (0, 0),
216        }
217    }
218
219    fn span_to_line_col(&self, span: shape_ast::ast::Span) -> Option<(usize, usize)> {
220        let source = self.source.as_ref()?;
221        let start = span.start.min(source.len());
222        let prefix = &source[..start];
223        let line = prefix.bytes().filter(|b| *b == b'\n').count() + 1;
224        let line_start = prefix.rfind('\n').map(|idx| idx + 1).unwrap_or(0);
225        let column = prefix[line_start..].chars().count() + 1;
226        Some((line, column))
227    }
228
229    /// Check all items in the program
230    fn check_items(&mut self, items: &[Item]) {
231        for item in items {
232            self.check_item(item);
233        }
234    }
235
236    /// Check all expressions in the program for exhaustiveness
237    fn check_expressions(&mut self, items: &[Item]) {
238        for item in items {
239            self.check_item_expressions(item);
240        }
241    }
242
243    /// Check expressions within an item
244    fn check_item_expressions(&mut self, item: &Item) {
245        if let Item::Function(func, _) = item {
246            // Set up function parameter context for type resolution
247            self.current_function_params.clear();
248            for param in &func.params {
249                if let Some(type_ann) = &param.type_annotation {
250                    // Insert all identifiers from the pattern
251                    for name in param.get_identifiers() {
252                        self.current_function_params.insert(name, type_ann.clone());
253                    }
254                }
255            }
256
257            for stmt in &func.body {
258                self.check_statement_expressions(stmt);
259            }
260
261            // Clear function parameter context
262            self.current_function_params.clear();
263        }
264    }
265
266    /// Check expressions within a statement
267    fn check_statement_expressions(&mut self, stmt: &Statement) {
268        match stmt {
269            Statement::Expression(expr, _) => self.check_expr(expr),
270            Statement::Return(Some(expr), _) => self.check_expr(expr),
271            Statement::VariableDecl(decl, _) => {
272                if let Some(init) = &decl.value {
273                    self.check_expr(init);
274                }
275            }
276            Statement::If(if_stmt, _) => {
277                self.check_expr(&if_stmt.condition);
278                for stmt in &if_stmt.then_body {
279                    self.check_statement_expressions(stmt);
280                }
281                if let Some(else_body) = &if_stmt.else_body {
282                    for stmt in else_body {
283                        self.check_statement_expressions(stmt);
284                    }
285                }
286            }
287            Statement::While(while_loop, _) => {
288                self.check_expr(&while_loop.condition);
289                for stmt in &while_loop.body {
290                    self.check_statement_expressions(stmt);
291                }
292            }
293            Statement::For(for_loop, _) => {
294                for stmt in &for_loop.body {
295                    self.check_statement_expressions(stmt);
296                }
297            }
298            _ => {}
299        }
300    }
301
302    /// Check a single expression for type issues
303    ///
304    /// Note: Exhaustiveness checking is now handled by the inference engine
305    /// during `infer_expr()`, so we don't need to check it here.
306    fn check_expr(&mut self, expr: &Expr) {
307        match expr {
308            Expr::Match(match_expr, _span) => {
309                // Exhaustiveness is checked by inference engine - just check sub-expressions
310                self.check_expr(&match_expr.scrutinee);
311                for arm in &match_expr.arms {
312                    if let Some(guard) = &arm.guard {
313                        self.check_expr(guard);
314                    }
315                    self.check_expr(&arm.body);
316                }
317            }
318            // Recursively check sub-expressions
319            Expr::BinaryOp { left, right, .. } => {
320                self.check_expr(left);
321                self.check_expr(right);
322            }
323            Expr::UnaryOp { operand, .. } => {
324                self.check_expr(operand);
325            }
326            Expr::Conditional {
327                condition,
328                then_expr,
329                else_expr,
330                ..
331            } => {
332                self.check_expr(condition);
333                self.check_expr(then_expr);
334                if let Some(else_e) = else_expr {
335                    self.check_expr(else_e);
336                }
337            }
338            Expr::If(if_expr, _) => {
339                self.check_expr(&if_expr.condition);
340                self.check_expr(&if_expr.then_branch);
341                if let Some(else_branch) = &if_expr.else_branch {
342                    self.check_expr(else_branch);
343                }
344            }
345            Expr::FunctionCall { args, .. } => {
346                for arg in args {
347                    self.check_expr(arg);
348                }
349            }
350            Expr::MethodCall { receiver, args, .. } => {
351                self.check_expr(receiver);
352                for arg in args {
353                    self.check_expr(arg);
354                }
355            }
356            Expr::Array(elems, _) => {
357                for elem in elems {
358                    self.check_expr(elem);
359                }
360            }
361            Expr::PropertyAccess { object, .. } => {
362                self.check_expr(object);
363            }
364            Expr::IndexAccess {
365                object,
366                index,
367                end_index,
368                ..
369            } => {
370                self.check_expr(object);
371                self.check_expr(index);
372                if let Some(end) = end_index {
373                    self.check_expr(end);
374                }
375            }
376            _ => {}
377        }
378    }
379
    // Note: check_match_exhaustiveness, resolve_named_to_enum, and span_to_location
    // were removed because exhaustiveness checking is now handled by the inference
    // engine in TypeInferenceEngine::infer_expr() for Match expressions.
383
384    /// Check a single item
385    fn check_item(&mut self, item: &Item) {
386        match item {
387            Item::Function(func, span) => {
388                // Check for missing return statements
389                if func.return_type.is_some()
390                    && !matches!(func.return_type.as_ref().unwrap(), TypeAnnotation::Void)
391                    && !self.has_return_statement(&func.body)
392                {
393                    let (line, col) = self.item_span_to_line_col(*span);
394                    self.add_error(TypeError::MissingReturn(func.name.clone()), line, col);
395                }
396            }
397
398            Item::TypeAlias(alias, span) => {
399                // Check for cyclic type aliases
400                if self.is_cyclic_type_alias(&alias.name, &alias.type_annotation) {
401                    let (line, col) = self.item_span_to_line_col(*span);
402                    self.add_error(TypeError::CyclicTypeAlias(alias.name.clone()), line, col);
403                }
404            }
405
406            Item::Interface(interface, span) => {
407                // Validate interface definition
408                self.check_interface(interface, *span);
409            }
410
411            _ => {}
412        }
413    }
414
415    /// Check if statements contain a return statement
416    fn has_return_statement(&self, stmts: &[Statement]) -> bool {
417        for stmt in stmts {
418            match stmt {
419                Statement::Return(_, _) => return true,
420                Statement::If(if_stmt, _) => {
421                    // Both branches must have returns
422                    if let Some(else_body) = &if_stmt.else_body {
423                        if self.has_return_statement(&if_stmt.then_body)
424                            && self.has_return_statement(else_body)
425                        {
426                            return true;
427                        }
428                    }
429                }
430                Statement::While(while_loop, _) => {
431                    if self.has_return_statement(&while_loop.body) {
432                        // Note: This is conservative - while loop might not execute
433                        return true;
434                    }
435                }
436                Statement::For(for_loop, _) => {
437                    if self.has_return_statement(&for_loop.body) {
438                        // Note: This is conservative - for loop might not execute
439                        return true;
440                    }
441                }
442                _ => {}
443            }
444        }
445
446        false
447    }
448
449    /// Check for cyclic type aliases
450    fn is_cyclic_type_alias(&self, name: &str, ty: &TypeAnnotation) -> bool {
451        self.references_type(ty, name)
452    }
453
454    /// Check if a type annotation references a specific type name
455    fn references_type(&self, ty: &TypeAnnotation, name: &str) -> bool {
456        match ty {
457            TypeAnnotation::Reference(ref_name) => ref_name == name,
458            TypeAnnotation::Array(elem) => self.references_type(elem, name),
459            TypeAnnotation::Tuple(elems) => {
460                elems.iter().any(|elem| self.references_type(elem, name))
461            }
462            TypeAnnotation::Object(fields) => fields
463                .iter()
464                .any(|field| self.references_type(&field.type_annotation, name)),
465            TypeAnnotation::Function { params, returns } => {
466                params
467                    .iter()
468                    .any(|param| self.references_type(&param.type_annotation, name))
469                    || self.references_type(returns, name)
470            }
471            TypeAnnotation::Union(types) => types.iter().any(|ty| self.references_type(ty, name)),
472            TypeAnnotation::Generic { args, .. } => {
473                args.iter().any(|arg| self.references_type(arg, name))
474            }
475            _ => false,
476        }
477    }
478
479    /// Check interface definition
480    fn check_interface(&mut self, interface: &shape_ast::ast::InterfaceDef, interface_span: Span) {
481        // Check for duplicate members
482        let mut seen_members = HashMap::new();
483
484        for (i, member) in interface.members.iter().enumerate() {
485            let member_name = match member {
486                shape_ast::ast::InterfaceMember::Property { name, .. } => name,
487                shape_ast::ast::InterfaceMember::Method { name, .. } => name,
488                shape_ast::ast::InterfaceMember::IndexSignature { .. } => continue,
489            };
490
491            if let Some(_prev_index) = seen_members.get(member_name) {
492                let (line, col) = self.item_span_to_line_col(interface_span);
493                self.add_error(
494                    TypeError::InterfaceError(
495                        interface.name.clone(),
496                        format!("Duplicate member '{}'", member_name),
497                    ),
498                    line,
499                    col,
500                );
501            } else {
502                seen_members.insert(member_name.clone(), i);
503            }
504        }
505    }
506
507    fn item_span_to_line_col(&self, span: Span) -> (usize, usize) {
508        self.span_to_line_col(span).unwrap_or((0, 0))
509    }
510
511    /// Add an error with location information
512    fn add_error(&mut self, error: TypeError, line: usize, column: usize) {
513        let mut err = TypeErrorWithLocation::new(error, line, column);
514
515        if let Some(filename) = &self.filename {
516            err = err.with_file(filename.clone());
517        }
518
519        if let Some(source) = &self.source {
520            // Extract the source line
521            if let Some(source_line) = source.lines().nth(line.saturating_sub(1)) {
522                err = err.with_source_line(source_line.to_string());
523            }
524        }
525
526        self.errors.push(err);
527    }
528
529    /// Get all collected errors
530    pub fn errors(&self) -> &[TypeErrorWithLocation] {
531        &self.errors
532    }
533
534    /// Format all errors for display
535    pub fn format_errors(&self) -> String {
536        self.errors
537            .iter()
538            .map(|err| err.format_with_source())
539            .collect::<Vec<_>>()
540            .join("\n")
541    }
542}
543
544/// Shared single-entry type analysis used by compiler and LSP.
545pub fn analyze_program(
546    program: &Program,
547    source: Option<&str>,
548    filename: Option<&str>,
549    known_bindings: Option<&[String]>,
550) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
551    analyze_program_with_mode(
552        program,
553        source,
554        filename,
555        known_bindings,
556        TypeAnalysisMode::FailFast,
557    )
558}
559
560/// Shared type analysis with explicit recovery behavior.
561pub fn analyze_program_with_mode(
562    program: &Program,
563    source: Option<&str>,
564    filename: Option<&str>,
565    known_bindings: Option<&[String]>,
566    analysis_mode: TypeAnalysisMode,
567) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
568    let mut checker = TypeChecker::new();
569    if let Some(src) = source {
570        checker = checker.with_source(src.to_string());
571    }
572    if let Some(file) = filename {
573        checker = checker.with_filename(file.to_string());
574    }
575    if let Some(names) = known_bindings {
576        checker = checker.with_known_bindings(names);
577    }
578    checker = checker.with_analysis_mode(analysis_mode);
579    checker.check_program(program)
580}
581
582/// Result of type checking
583#[derive(Debug)]
584pub struct TypeCheckResult {
585    /// Inferred types for all declarations (inference-level types)
586    pub types: HashMap<String, Type>,
587    /// Semantic types for all declarations (user-facing types)
588    pub semantic_types: HashMap<String, SemanticType>,
589    /// Type warnings (non-fatal issues)
590    pub warnings: Vec<TypeWarning>,
591}
592
593impl TypeCheckResult {
594    /// Get the semantic type for a declaration
595    pub fn get_semantic_type(&self, name: &str) -> Option<&SemanticType> {
596        self.semantic_types.get(name)
597    }
598
599    /// Get all function declarations that are fallible (return Result)
600    pub fn fallible_functions(&self) -> Vec<&str> {
601        self.semantic_types
602            .iter()
603            .filter_map(|(name, ty)| {
604                if let SemanticType::Function(sig) = ty {
605                    if sig.return_type.is_result() {
606                        return Some(name.as_str());
607                    }
608                }
609                None
610            })
611            .collect()
612    }
613}
614
/// A non-fatal finding produced during type checking.
#[derive(Debug)]
pub struct TypeWarning {
    /// Human-readable description of the finding.
    pub message: String,
    /// Line the warning refers to.
    pub line: usize,
    /// Column the warning refers to.
    pub column: usize,
}
622
623/// Type check an expression and return its type
624pub fn type_of_expr(expr: &Expr, _env: &TypeEnvironment) -> TypeResult<Type> {
625    let mut engine = TypeInferenceEngine::new();
626    engine.infer_expr(expr)
627}
628
629/// Quick type check for REPL and testing
630pub fn quick_check(source: &str) -> Result<TypeCheckResult, String> {
631    use shape_ast::parser::parse_program;
632
633    let program = parse_program(source).map_err(|e| format!("Parse error: {}", e))?;
634
635    let mut checker = TypeChecker::new().with_source(source.to_string());
636
637    checker.check_program(&program).map_err(|errors| {
638        errors
639            .iter()
640            .map(|e| e.format_with_source())
641            .collect::<Vec<_>>()
642            .join("\n")
643    })
644}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// True when a rendered error report mentions a non-exhaustive match.
    fn mentions_non_exhaustive(report: &str) -> bool {
        report.contains("NonExhaustive") || report.contains("non-exhaustive")
    }

    #[test]
    fn test_exhaustiveness_integration_non_exhaustive_match_produces_error() {
        // Exhaustiveness checking must be wired into the pipeline: a match
        // that leaves enum variants uncovered has to be rejected.
        let source = r#"
            enum Status { Active, Inactive, Pending }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes"
                };
            }
        "#;

        let outcome = quick_check(source);

        // `Inactive` and `Pending` are not covered, so analysis must fail.
        assert!(
            outcome.is_err(),
            "Expected error for non-exhaustive match, got: {:?}",
            outcome
        );
        let report = outcome.unwrap_err();
        assert!(
            mentions_non_exhaustive(&report) || report.contains("missing"),
            "Expected non-exhaustive match error, got: {}",
            report
        );
    }

    #[test]
    fn test_exhaustiveness_integration_exhaustive_match_succeeds() {
        // Every variant is covered.
        let source = r#"
            enum Status { Active, Inactive }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes",
                    Status::Inactive => "no"
                };
            }
        "#;

        // Other errors are tolerated here; only a non-exhaustive report is not.
        if let Err(report) = quick_check(source) {
            assert!(
                !mentions_non_exhaustive(&report),
                "Should not have non-exhaustive error for exhaustive match, got: {}",
                report
            );
        }
    }

    #[test]
    fn test_exhaustiveness_integration_wildcard_makes_exhaustive() {
        // A wildcard arm covers whatever the explicit arms leave out.
        let source = r#"
            enum Status { Active, Inactive, Pending }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes",
                    _ => "other"
                };
            }
        "#;

        if let Err(report) = quick_check(source) {
            assert!(
                !mentions_non_exhaustive(&report),
                "Wildcard should make match exhaustive, got: {}",
                report
            );
        }
    }

    #[test]
    fn test_undefined_variable_reports_identifier_position() {
        use shape_ast::parser::parse_program;

        let source = r#"
let x = 1
let y = duckdb.connect("duckdb://analytics.db")
"#;

        let program = parse_program(source).expect("program should parse");
        let errors = analyze_program(&program, Some(source), None, None)
            .expect_err("undefined variable should fail analysis");
        let undefined = errors
            .iter()
            .find(|e| matches!(&e.error, TypeError::UndefinedVariable(name) if name == "duckdb"))
            .expect("missing undefined-variable error for duckdb");

        // `duckdb` starts on the third line of the source, at column nine.
        assert_eq!(undefined.line, 3);
        assert_eq!(undefined.column, 9);
    }

    #[test]
    fn test_known_bindings_allow_extension_namespace_in_type_analysis() {
        use shape_ast::parser::parse_program;

        let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
        let program = parse_program(source).expect("program should parse");
        let known_names = vec!["duckdb".to_string()];

        let outcome = analyze_program(&program, Some(source), None, Some(&known_names));
        assert!(
            outcome.is_ok(),
            "known extension namespaces should not fail type analysis: {:?}",
            outcome.err()
        );
    }
}
772}