1use super::errors::{TypeError, TypeErrorWithLocation, TypeResult};
7use super::inference::TypeInferenceEngine;
8use super::*;
9use shape_ast::ast::{EnumDef, Expr, Item, Program, Span, Statement, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
/// Controls how [`TypeChecker::check_program`] reacts to type-inference errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeAnalysisMode {
    /// Stop at the first inference error and report only that error.
    FailFast,
    /// Run best-effort inference, record every inference error it reports, and
    /// continue with the remaining checks.
    RecoverAll,
}
17
18pub struct TypeChecker {
19 inference_engine: TypeInferenceEngine,
21 errors: Vec<TypeErrorWithLocation>,
23 source: Option<String>,
25 filename: Option<String>,
27 enum_defs: HashMap<String, EnumDef>,
29 current_function_params: HashMap<String, shape_ast::ast::TypeAnnotation>,
31 analysis_mode: TypeAnalysisMode,
33}
34
35impl Default for TypeChecker {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl TypeChecker {
42 pub fn new() -> Self {
43 TypeChecker {
44 inference_engine: TypeInferenceEngine::new(),
45 errors: Vec::new(),
46 source: None,
47 filename: None,
48 enum_defs: HashMap::new(),
49 current_function_params: HashMap::new(),
50 analysis_mode: TypeAnalysisMode::FailFast,
51 }
52 }
53
54 pub fn with_source(mut self, source: String) -> Self {
56 self.source = Some(source);
57 self
58 }
59
60 pub fn with_filename(mut self, filename: String) -> Self {
62 self.filename = Some(filename);
63 self
64 }
65
66 pub fn with_known_bindings(mut self, names: &[String]) -> Self {
68 self.inference_engine.register_known_bindings(names);
69 self
70 }
71
72 pub fn with_analysis_mode(mut self, mode: TypeAnalysisMode) -> Self {
73 self.analysis_mode = mode;
74 self
75 }
76
77 pub fn check_program(
79 &mut self,
80 program: &Program,
81 ) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
82 self.errors.clear();
84 self.enum_defs.clear();
85
86 for item in &program.items {
88 if let Item::Enum(enum_def, _) = item {
89 self.enum_defs
90 .insert(enum_def.name.clone(), enum_def.clone());
91 }
92 }
93
94 let types = match self.analysis_mode {
95 TypeAnalysisMode::FailFast => match self.inference_engine.infer_program(program) {
96 Ok(types) => types,
97 Err(err) => {
98 self.add_inference_error(err);
99 return Err(self.errors.clone());
100 }
101 },
102 TypeAnalysisMode::RecoverAll => {
103 let (types, inference_errors) =
104 self.inference_engine.infer_program_best_effort(program);
105 for err in inference_errors {
106 self.add_inference_error(err);
107 }
108 types
109 }
110 };
111
112 self.check_items(&program.items);
114
115 self.check_expressions(&program.items);
117
118 self.prune_error_cascades();
119
120 if self.errors.is_empty() {
121 let semantic_types: HashMap<String, SemanticType> = types
123 .iter()
124 .filter_map(|(name, ty)| ty.to_semantic().map(|st| (name.clone(), st)))
125 .collect();
126
127 Ok(TypeCheckResult {
128 types,
129 semantic_types,
130 warnings: Vec::new(),
131 })
132 } else {
133 Err(self.errors.clone())
134 }
135 }
136
137 fn add_inference_error(&mut self, err: TypeError) {
138 let (line, col) = self.find_inference_error_position(&err);
139 self.add_error(err, line, col);
140 }
141
142 fn prune_error_cascades(&mut self) {
143 let has_specific_errors = self
144 .errors
145 .iter()
146 .any(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
147 if has_specific_errors {
148 self.errors
149 .retain(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
150 }
151
152 let mut seen = HashSet::new();
153 self.errors.retain(|err| {
154 let key = (err.line, err.column, err.error.to_string());
155 seen.insert(key)
156 });
157 }
158
159 fn find_inference_error_position(&self, error: &TypeError) -> (usize, usize) {
160 match error {
161 TypeError::UnknownProperty(_, property) => {
162 if let Some(span) = self
163 .inference_engine
164 .lookup_unknown_property_origin(property)
165 {
166 if let Some((line, col)) = self.span_to_line_col(span) {
167 return (line, col);
168 }
169 }
170 (0, 0)
171 }
172 TypeError::UndefinedVariable(name) => self
173 .inference_engine
174 .lookup_undefined_variable_origin(name)
175 .and_then(|span| self.span_to_line_col(span))
176 .unwrap_or((0, 0)),
177 TypeError::UnsolvedConstraints(constraints) => {
178 if let Some(span) = self
179 .inference_engine
180 .find_origin_for_unsolved_constraints(constraints)
181 {
182 if let Some((line, col)) = self.span_to_line_col(span) {
183 return (line, col);
184 }
185 }
186 if let Some(span) = self.inference_engine.find_any_constraint_origin() {
187 if let Some((line, col)) = self.span_to_line_col(span) {
188 return (line, col);
189 }
190 }
191 (0, 0)
192 }
193 TypeError::InvalidAssertion(_, _) => (0, 0),
194 TypeError::NonExhaustiveMatch { enum_name, .. } => self
195 .inference_engine
196 .lookup_non_exhaustive_match_origin(enum_name)
197 .and_then(|span| self.span_to_line_col(span))
198 .unwrap_or((0, 0)),
199 TypeError::GenericTypeError { symbol, .. } => {
200 if let Some(symbol) = symbol
201 && let Some(span) = self
202 .inference_engine
203 .lookup_callable_origin_for_name(symbol)
204 && let Some((line, col)) = self.span_to_line_col(span)
205 {
206 return (line, col);
207 }
208 if let Some(span) = self.inference_engine.find_any_constraint_origin() {
209 if let Some((line, col)) = self.span_to_line_col(span) {
210 return (line, col);
211 }
212 }
213 (0, 0)
214 }
215 _ => (0, 0),
216 }
217 }
218
219 fn span_to_line_col(&self, span: shape_ast::ast::Span) -> Option<(usize, usize)> {
220 let source = self.source.as_ref()?;
221 let start = span.start.min(source.len());
222 let prefix = &source[..start];
223 let line = prefix.bytes().filter(|b| *b == b'\n').count() + 1;
224 let line_start = prefix.rfind('\n').map(|idx| idx + 1).unwrap_or(0);
225 let column = prefix[line_start..].chars().count() + 1;
226 Some((line, column))
227 }
228
229 fn check_items(&mut self, items: &[Item]) {
231 for item in items {
232 self.check_item(item);
233 }
234 }
235
236 fn check_expressions(&mut self, items: &[Item]) {
238 for item in items {
239 self.check_item_expressions(item);
240 }
241 }
242
243 fn check_item_expressions(&mut self, item: &Item) {
245 if let Item::Function(func, _) = item {
246 self.current_function_params.clear();
248 for param in &func.params {
249 if let Some(type_ann) = ¶m.type_annotation {
250 for name in param.get_identifiers() {
252 self.current_function_params.insert(name, type_ann.clone());
253 }
254 }
255 }
256
257 for stmt in &func.body {
258 self.check_statement_expressions(stmt);
259 }
260
261 self.current_function_params.clear();
263 }
264 }
265
266 fn check_statement_expressions(&mut self, stmt: &Statement) {
268 match stmt {
269 Statement::Expression(expr, _) => self.check_expr(expr),
270 Statement::Return(Some(expr), _) => self.check_expr(expr),
271 Statement::VariableDecl(decl, _) => {
272 if let Some(init) = &decl.value {
273 self.check_expr(init);
274 }
275 }
276 Statement::If(if_stmt, _) => {
277 self.check_expr(&if_stmt.condition);
278 for stmt in &if_stmt.then_body {
279 self.check_statement_expressions(stmt);
280 }
281 if let Some(else_body) = &if_stmt.else_body {
282 for stmt in else_body {
283 self.check_statement_expressions(stmt);
284 }
285 }
286 }
287 Statement::While(while_loop, _) => {
288 self.check_expr(&while_loop.condition);
289 for stmt in &while_loop.body {
290 self.check_statement_expressions(stmt);
291 }
292 }
293 Statement::For(for_loop, _) => {
294 for stmt in &for_loop.body {
295 self.check_statement_expressions(stmt);
296 }
297 }
298 _ => {}
299 }
300 }
301
302 fn check_expr(&mut self, expr: &Expr) {
307 match expr {
308 Expr::Match(match_expr, _span) => {
309 self.check_expr(&match_expr.scrutinee);
311 for arm in &match_expr.arms {
312 if let Some(guard) = &arm.guard {
313 self.check_expr(guard);
314 }
315 self.check_expr(&arm.body);
316 }
317 }
318 Expr::BinaryOp { left, right, .. } => {
320 self.check_expr(left);
321 self.check_expr(right);
322 }
323 Expr::UnaryOp { operand, .. } => {
324 self.check_expr(operand);
325 }
326 Expr::Conditional {
327 condition,
328 then_expr,
329 else_expr,
330 ..
331 } => {
332 self.check_expr(condition);
333 self.check_expr(then_expr);
334 if let Some(else_e) = else_expr {
335 self.check_expr(else_e);
336 }
337 }
338 Expr::If(if_expr, _) => {
339 self.check_expr(&if_expr.condition);
340 self.check_expr(&if_expr.then_branch);
341 if let Some(else_branch) = &if_expr.else_branch {
342 self.check_expr(else_branch);
343 }
344 }
345 Expr::FunctionCall { args, .. } => {
346 for arg in args {
347 self.check_expr(arg);
348 }
349 }
350 Expr::MethodCall { receiver, args, .. } => {
351 self.check_expr(receiver);
352 for arg in args {
353 self.check_expr(arg);
354 }
355 }
356 Expr::Array(elems, _) => {
357 for elem in elems {
358 self.check_expr(elem);
359 }
360 }
361 Expr::PropertyAccess { object, .. } => {
362 self.check_expr(object);
363 }
364 Expr::IndexAccess {
365 object,
366 index,
367 end_index,
368 ..
369 } => {
370 self.check_expr(object);
371 self.check_expr(index);
372 if let Some(end) = end_index {
373 self.check_expr(end);
374 }
375 }
376 _ => {}
377 }
378 }
379
380 fn check_item(&mut self, item: &Item) {
386 match item {
387 Item::Function(func, span) => {
388 if func.return_type.is_some()
390 && !matches!(func.return_type.as_ref().unwrap(), TypeAnnotation::Void)
391 && !self.has_return_statement(&func.body)
392 {
393 let (line, col) = self.item_span_to_line_col(*span);
394 self.add_error(TypeError::MissingReturn(func.name.clone()), line, col);
395 }
396 }
397
398 Item::TypeAlias(alias, span) => {
399 if self.is_cyclic_type_alias(&alias.name, &alias.type_annotation) {
401 let (line, col) = self.item_span_to_line_col(*span);
402 self.add_error(TypeError::CyclicTypeAlias(alias.name.clone()), line, col);
403 }
404 }
405
406 Item::Interface(interface, span) => {
407 self.check_interface(interface, *span);
409 }
410
411 _ => {}
412 }
413 }
414
415 fn has_return_statement(&self, stmts: &[Statement]) -> bool {
417 for stmt in stmts {
418 match stmt {
419 Statement::Return(_, _) => return true,
420 Statement::If(if_stmt, _) => {
421 if let Some(else_body) = &if_stmt.else_body {
423 if self.has_return_statement(&if_stmt.then_body)
424 && self.has_return_statement(else_body)
425 {
426 return true;
427 }
428 }
429 }
430 Statement::While(while_loop, _) => {
431 if self.has_return_statement(&while_loop.body) {
432 return true;
434 }
435 }
436 Statement::For(for_loop, _) => {
437 if self.has_return_statement(&for_loop.body) {
438 return true;
440 }
441 }
442 _ => {}
443 }
444 }
445
446 false
447 }
448
449 fn is_cyclic_type_alias(&self, name: &str, ty: &TypeAnnotation) -> bool {
451 self.references_type(ty, name)
452 }
453
454 fn references_type(&self, ty: &TypeAnnotation, name: &str) -> bool {
456 match ty {
457 TypeAnnotation::Reference(ref_name) => ref_name == name,
458 TypeAnnotation::Array(elem) => self.references_type(elem, name),
459 TypeAnnotation::Tuple(elems) => {
460 elems.iter().any(|elem| self.references_type(elem, name))
461 }
462 TypeAnnotation::Object(fields) => fields
463 .iter()
464 .any(|field| self.references_type(&field.type_annotation, name)),
465 TypeAnnotation::Function { params, returns } => {
466 params
467 .iter()
468 .any(|param| self.references_type(¶m.type_annotation, name))
469 || self.references_type(returns, name)
470 }
471 TypeAnnotation::Union(types) => types.iter().any(|ty| self.references_type(ty, name)),
472 TypeAnnotation::Optional(ty) => self.references_type(ty, name),
473 TypeAnnotation::Generic { args, .. } => {
474 args.iter().any(|arg| self.references_type(arg, name))
475 }
476 _ => false,
477 }
478 }
479
480 fn check_interface(&mut self, interface: &shape_ast::ast::InterfaceDef, interface_span: Span) {
482 let mut seen_members = HashMap::new();
484
485 for (i, member) in interface.members.iter().enumerate() {
486 let member_name = match member {
487 shape_ast::ast::InterfaceMember::Property { name, .. } => name,
488 shape_ast::ast::InterfaceMember::Method { name, .. } => name,
489 shape_ast::ast::InterfaceMember::IndexSignature { .. } => continue,
490 };
491
492 if let Some(_prev_index) = seen_members.get(member_name) {
493 let (line, col) = self.item_span_to_line_col(interface_span);
494 self.add_error(
495 TypeError::InterfaceError(
496 interface.name.clone(),
497 format!("Duplicate member '{}'", member_name),
498 ),
499 line,
500 col,
501 );
502 } else {
503 seen_members.insert(member_name.clone(), i);
504 }
505 }
506 }
507
508 fn item_span_to_line_col(&self, span: Span) -> (usize, usize) {
509 self.span_to_line_col(span).unwrap_or((0, 0))
510 }
511
512 fn add_error(&mut self, error: TypeError, line: usize, column: usize) {
514 let mut err = TypeErrorWithLocation::new(error, line, column);
515
516 if let Some(filename) = &self.filename {
517 err = err.with_file(filename.clone());
518 }
519
520 if let Some(source) = &self.source {
521 if let Some(source_line) = source.lines().nth(line.saturating_sub(1)) {
523 err = err.with_source_line(source_line.to_string());
524 }
525 }
526
527 self.errors.push(err);
528 }
529
530 pub fn errors(&self) -> &[TypeErrorWithLocation] {
532 &self.errors
533 }
534
535 pub fn format_errors(&self) -> String {
537 self.errors
538 .iter()
539 .map(|err| err.format_with_source())
540 .collect::<Vec<_>>()
541 .join("\n")
542 }
543}
544
545pub fn analyze_program(
547 program: &Program,
548 source: Option<&str>,
549 filename: Option<&str>,
550 known_bindings: Option<&[String]>,
551) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
552 analyze_program_with_mode(
553 program,
554 source,
555 filename,
556 known_bindings,
557 TypeAnalysisMode::FailFast,
558 )
559}
560
561pub fn analyze_program_with_mode(
563 program: &Program,
564 source: Option<&str>,
565 filename: Option<&str>,
566 known_bindings: Option<&[String]>,
567 analysis_mode: TypeAnalysisMode,
568) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
569 let mut checker = TypeChecker::new();
570 if let Some(src) = source {
571 checker = checker.with_source(src.to_string());
572 }
573 if let Some(file) = filename {
574 checker = checker.with_filename(file.to_string());
575 }
576 if let Some(names) = known_bindings {
577 checker = checker.with_known_bindings(names);
578 }
579 checker = checker.with_analysis_mode(analysis_mode);
580 checker.check_program(program)
581}
582
583#[derive(Debug)]
585pub struct TypeCheckResult {
586 pub types: HashMap<String, Type>,
588 pub semantic_types: HashMap<String, SemanticType>,
590 pub warnings: Vec<TypeWarning>,
592}
593
594impl TypeCheckResult {
595 pub fn get_semantic_type(&self, name: &str) -> Option<&SemanticType> {
597 self.semantic_types.get(name)
598 }
599
600 pub fn fallible_functions(&self) -> Vec<&str> {
602 self.semantic_types
603 .iter()
604 .filter_map(|(name, ty)| {
605 if let SemanticType::Function(sig) = ty {
606 if sig.return_type.is_result() {
607 return Some(name.as_str());
608 }
609 }
610 None
611 })
612 .collect()
613 }
614}
615
/// A non-fatal diagnostic that may be attached to a [`TypeCheckResult`].
#[derive(Debug)]
pub struct TypeWarning {
    /// Human-readable description of the warning.
    pub message: String,
    /// Line of the warning; presumably 1-based like error positions — TODO confirm.
    pub line: usize,
    /// Column of the warning; presumably 1-based like error positions — TODO confirm.
    pub column: usize,
}
623
624pub fn type_of_expr(expr: &Expr, _env: &TypeEnvironment) -> TypeResult<Type> {
626 let mut engine = TypeInferenceEngine::new();
627 engine.infer_expr(expr)
628}
629
630pub fn quick_check(source: &str) -> Result<TypeCheckResult, String> {
632 use shape_ast::parser::parse_program;
633
634 let program = parse_program(source).map_err(|e| format!("Parse error: {}", e))?;
635
636 let mut checker = TypeChecker::new().with_source(source.to_string());
637
638 checker.check_program(&program).map_err(|errors| {
639 errors
640 .iter()
641 .map(|e| e.format_with_source())
642 .collect::<Vec<_>>()
643 .join("\n")
644 })
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// True when a rendered error report mentions a non-exhaustive match.
    fn mentions_non_exhaustive(report: &str) -> bool {
        report.contains("NonExhaustive") || report.contains("non-exhaustive")
    }

    #[test]
    fn test_exhaustiveness_integration_non_exhaustive_match_produces_error() {
        // `Inactive` and `Pending` are not covered by the match below.
        let source = r#"
        enum Status { Active, Inactive, Pending }

        function check(s: Status) {
            return match s {
                Status::Active => "yes"
            };
        }
        "#;

        let result = quick_check(source);
        assert!(
            result.is_err(),
            "Expected error for non-exhaustive match, got: {:?}",
            result
        );

        let report = result.unwrap_err();
        assert!(
            mentions_non_exhaustive(&report) || report.contains("missing"),
            "Expected non-exhaustive match error, got: {}",
            report
        );
    }

    #[test]
    fn test_exhaustiveness_integration_exhaustive_match_succeeds() {
        let source = r#"
        enum Status { Active, Inactive }

        function check(s: Status) {
            return match s {
                Status::Active => "yes",
                Status::Inactive => "no"
            };
        }
        "#;

        // Unrelated errors are tolerated; only a spurious exhaustiveness error fails.
        if let Err(report) = &quick_check(source) {
            assert!(
                !mentions_non_exhaustive(report),
                "Should not have non-exhaustive error for exhaustive match, got: {}",
                report
            );
        }
    }

    #[test]
    fn test_exhaustiveness_integration_wildcard_makes_exhaustive() {
        let source = r#"
        enum Status { Active, Inactive, Pending }

        function check(s: Status) {
            return match s {
                Status::Active => "yes",
                _ => "other"
            };
        }
        "#;

        // The `_` arm covers `Inactive` and `Pending`.
        if let Err(report) = &quick_check(source) {
            assert!(
                !mentions_non_exhaustive(report),
                "Wildcard should make match exhaustive, got: {}",
                report
            );
        }
    }

    #[test]
    fn test_undefined_variable_reports_identifier_position() {
        use shape_ast::parser::parse_program;

        // The leading newline matters: `duckdb` sits on line 3, column 9.
        let source = r#"
let x = 1
let y = duckdb.connect("duckdb://analytics.db")
"#;

        let program = parse_program(source).expect("program should parse");
        let errors = analyze_program(&program, Some(source), None, None)
            .expect_err("undefined variable should fail analysis");
        let undefined = errors
            .iter()
            .find(|e| matches!(&e.error, TypeError::UndefinedVariable(name) if name == "duckdb"))
            .expect("missing undefined-variable error for duckdb");

        assert_eq!((undefined.line, undefined.column), (3, 9));
    }

    #[test]
    fn test_known_bindings_allow_extension_namespace_in_type_analysis() {
        use shape_ast::parser::parse_program;

        let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
        let program = parse_program(source).expect("program should parse");
        let known = vec!["duckdb".to_string()];

        let result = analyze_program(&program, Some(source), None, Some(&known));
        assert!(
            result.is_ok(),
            "known extension namespaces should not fail type analysis: {:?}",
            result.err()
        );
    }
}