1use super::errors::{TypeError, TypeErrorWithLocation, TypeResult};
7use super::inference::TypeInferenceEngine;
8use super::*;
9use shape_ast::ast::{EnumDef, Expr, Item, Program, Span, Statement, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
/// Controls how [`TypeChecker::check_program`] reacts to inference failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeAnalysisMode {
    /// Abort at the first inference error and report only that error.
    FailFast,
    /// Run best-effort inference, collect every inference error, and still
    /// run the structural item checks afterwards.
    RecoverAll,
}
17
18pub struct TypeChecker {
19 inference_engine: TypeInferenceEngine,
21 errors: Vec<TypeErrorWithLocation>,
23 source: Option<String>,
25 filename: Option<String>,
27 enum_defs: HashMap<String, EnumDef>,
29 current_function_params: HashMap<String, shape_ast::ast::TypeAnnotation>,
31 analysis_mode: TypeAnalysisMode,
33}
34
35impl Default for TypeChecker {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl TypeChecker {
42 pub fn new() -> Self {
43 TypeChecker {
44 inference_engine: TypeInferenceEngine::new(),
45 errors: Vec::new(),
46 source: None,
47 filename: None,
48 enum_defs: HashMap::new(),
49 current_function_params: HashMap::new(),
50 analysis_mode: TypeAnalysisMode::FailFast,
51 }
52 }
53
54 pub fn with_source(mut self, source: String) -> Self {
56 self.source = Some(source);
57 self
58 }
59
60 pub fn with_filename(mut self, filename: String) -> Self {
62 self.filename = Some(filename);
63 self
64 }
65
66 pub fn with_known_bindings(mut self, names: &[String]) -> Self {
68 self.inference_engine.register_known_bindings(names);
69 self
70 }
71
72 pub fn with_analysis_mode(mut self, mode: TypeAnalysisMode) -> Self {
73 self.analysis_mode = mode;
74 self
75 }
76
77 pub fn check_program(
79 &mut self,
80 program: &Program,
81 ) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
82 self.errors.clear();
84 self.enum_defs.clear();
85
86 for item in &program.items {
88 if let Item::Enum(enum_def, _) = item {
89 self.enum_defs
90 .insert(enum_def.name.clone(), enum_def.clone());
91 }
92 }
93
94 let types = match self.analysis_mode {
95 TypeAnalysisMode::FailFast => match self.inference_engine.infer_program(program) {
96 Ok(types) => types,
97 Err(err) => {
98 self.add_inference_error(err);
99 return Err(self.errors.clone());
100 }
101 },
102 TypeAnalysisMode::RecoverAll => {
103 let (types, inference_errors) =
104 self.inference_engine.infer_program_best_effort(program);
105 for err in inference_errors {
106 self.add_inference_error(err);
107 }
108 types
109 }
110 };
111
112 self.check_items(&program.items);
114
115 self.check_expressions(&program.items);
117
118 self.prune_error_cascades();
119
120 if self.errors.is_empty() {
121 let semantic_types: HashMap<String, SemanticType> = types
123 .iter()
124 .filter_map(|(name, ty)| ty.to_semantic().map(|st| (name.clone(), st)))
125 .collect();
126
127 Ok(TypeCheckResult {
128 types,
129 semantic_types,
130 warnings: Vec::new(),
131 })
132 } else {
133 Err(self.errors.clone())
134 }
135 }
136
137 fn add_inference_error(&mut self, err: TypeError) {
138 let (line, col) = self.find_inference_error_position(&err);
139 self.add_error(err, line, col);
140 }
141
142 fn prune_error_cascades(&mut self) {
143 let has_specific_errors = self
144 .errors
145 .iter()
146 .any(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
147 if has_specific_errors {
148 self.errors
149 .retain(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
150 }
151
152 let mut seen = HashSet::new();
153 self.errors.retain(|err| {
154 let key = (err.line, err.column, err.error.to_string());
155 seen.insert(key)
156 });
157 }
158
159 fn find_inference_error_position(&self, error: &TypeError) -> (usize, usize) {
160 match error {
161 TypeError::UnknownProperty(_, property) => {
162 if let Some(span) = self
163 .inference_engine
164 .lookup_unknown_property_origin(property)
165 {
166 if let Some((line, col)) = self.span_to_line_col(span) {
167 return (line, col);
168 }
169 }
170 (0, 0)
171 }
172 TypeError::UndefinedVariable(name) => self
173 .inference_engine
174 .lookup_undefined_variable_origin(name)
175 .and_then(|span| self.span_to_line_col(span))
176 .unwrap_or((0, 0)),
177 TypeError::UnsolvedConstraints(constraints) => {
178 if let Some(span) = self
179 .inference_engine
180 .find_origin_for_unsolved_constraints(constraints)
181 {
182 if let Some((line, col)) = self.span_to_line_col(span) {
183 return (line, col);
184 }
185 }
186 if let Some(span) = self.inference_engine.find_any_constraint_origin() {
187 if let Some((line, col)) = self.span_to_line_col(span) {
188 return (line, col);
189 }
190 }
191 (0, 0)
192 }
193 TypeError::InvalidAssertion(_, _) => (0, 0),
194 TypeError::NonExhaustiveMatch { enum_name, .. } => self
195 .inference_engine
196 .lookup_non_exhaustive_match_origin(enum_name)
197 .and_then(|span| self.span_to_line_col(span))
198 .unwrap_or((0, 0)),
199 TypeError::GenericTypeError { symbol, .. } => {
200 if let Some(symbol) = symbol
201 && let Some(span) = self
202 .inference_engine
203 .lookup_callable_origin_for_name(symbol)
204 && let Some((line, col)) = self.span_to_line_col(span)
205 {
206 return (line, col);
207 }
208 if let Some(span) = self.inference_engine.find_any_constraint_origin() {
209 if let Some((line, col)) = self.span_to_line_col(span) {
210 return (line, col);
211 }
212 }
213 (0, 0)
214 }
215 _ => (0, 0),
216 }
217 }
218
219 fn span_to_line_col(&self, span: shape_ast::ast::Span) -> Option<(usize, usize)> {
220 let source = self.source.as_ref()?;
221 let start = span.start.min(source.len());
222 let prefix = &source[..start];
223 let line = prefix.bytes().filter(|b| *b == b'\n').count() + 1;
224 let line_start = prefix.rfind('\n').map(|idx| idx + 1).unwrap_or(0);
225 let column = prefix[line_start..].chars().count() + 1;
226 Some((line, column))
227 }
228
229 fn check_items(&mut self, items: &[Item]) {
231 for item in items {
232 self.check_item(item);
233 }
234 }
235
236 fn check_expressions(&mut self, items: &[Item]) {
238 for item in items {
239 self.check_item_expressions(item);
240 }
241 }
242
243 fn check_item_expressions(&mut self, item: &Item) {
245 if let Item::Function(func, _) = item {
246 self.current_function_params.clear();
248 for param in &func.params {
249 if let Some(type_ann) = ¶m.type_annotation {
250 for name in param.get_identifiers() {
252 self.current_function_params.insert(name, type_ann.clone());
253 }
254 }
255 }
256
257 for stmt in &func.body {
258 self.check_statement_expressions(stmt);
259 }
260
261 self.current_function_params.clear();
263 }
264 }
265
266 fn check_statement_expressions(&mut self, stmt: &Statement) {
268 match stmt {
269 Statement::Expression(expr, _) => self.check_expr(expr),
270 Statement::Return(Some(expr), _) => self.check_expr(expr),
271 Statement::VariableDecl(decl, _) => {
272 if let Some(init) = &decl.value {
273 self.check_expr(init);
274 }
275 }
276 Statement::If(if_stmt, _) => {
277 self.check_expr(&if_stmt.condition);
278 for stmt in &if_stmt.then_body {
279 self.check_statement_expressions(stmt);
280 }
281 if let Some(else_body) = &if_stmt.else_body {
282 for stmt in else_body {
283 self.check_statement_expressions(stmt);
284 }
285 }
286 }
287 Statement::While(while_loop, _) => {
288 self.check_expr(&while_loop.condition);
289 for stmt in &while_loop.body {
290 self.check_statement_expressions(stmt);
291 }
292 }
293 Statement::For(for_loop, _) => {
294 for stmt in &for_loop.body {
295 self.check_statement_expressions(stmt);
296 }
297 }
298 _ => {}
299 }
300 }
301
302 fn check_expr(&mut self, expr: &Expr) {
307 match expr {
308 Expr::Match(match_expr, _span) => {
309 self.check_expr(&match_expr.scrutinee);
311 for arm in &match_expr.arms {
312 if let Some(guard) = &arm.guard {
313 self.check_expr(guard);
314 }
315 self.check_expr(&arm.body);
316 }
317 }
318 Expr::BinaryOp { left, right, .. } => {
320 self.check_expr(left);
321 self.check_expr(right);
322 }
323 Expr::UnaryOp { operand, .. } => {
324 self.check_expr(operand);
325 }
326 Expr::Conditional {
327 condition,
328 then_expr,
329 else_expr,
330 ..
331 } => {
332 self.check_expr(condition);
333 self.check_expr(then_expr);
334 if let Some(else_e) = else_expr {
335 self.check_expr(else_e);
336 }
337 }
338 Expr::If(if_expr, _) => {
339 self.check_expr(&if_expr.condition);
340 self.check_expr(&if_expr.then_branch);
341 if let Some(else_branch) = &if_expr.else_branch {
342 self.check_expr(else_branch);
343 }
344 }
345 Expr::FunctionCall { args, .. } => {
346 for arg in args {
347 self.check_expr(arg);
348 }
349 }
350 Expr::QualifiedFunctionCall { args, .. } => {
351 for arg in args {
352 self.check_expr(arg);
353 }
354 }
355 Expr::MethodCall { receiver, args, .. } => {
356 self.check_expr(receiver);
357 for arg in args {
358 self.check_expr(arg);
359 }
360 }
361 Expr::Array(elems, _) => {
362 for elem in elems {
363 self.check_expr(elem);
364 }
365 }
366 Expr::PropertyAccess { object, .. } => {
367 self.check_expr(object);
368 }
369 Expr::IndexAccess {
370 object,
371 index,
372 end_index,
373 ..
374 } => {
375 self.check_expr(object);
376 self.check_expr(index);
377 if let Some(end) = end_index {
378 self.check_expr(end);
379 }
380 }
381 _ => {}
382 }
383 }
384
385 fn check_item(&mut self, item: &Item) {
391 match item {
392 Item::Function(func, span) => {
393 if func.return_type.is_some()
395 && !matches!(func.return_type.as_ref().unwrap(), TypeAnnotation::Void)
396 && !self.has_return_statement(&func.body)
397 {
398 let (line, col) = self.item_span_to_line_col(*span);
399 self.add_error(TypeError::MissingReturn(func.name.clone()), line, col);
400 }
401 }
402
403 Item::TypeAlias(alias, span) => {
404 if self.is_cyclic_type_alias(&alias.name, &alias.type_annotation) {
406 let (line, col) = self.item_span_to_line_col(*span);
407 self.add_error(TypeError::CyclicTypeAlias(alias.name.clone()), line, col);
408 }
409 }
410
411 Item::Interface(interface, span) => {
412 self.check_interface(interface, *span);
414 }
415
416 _ => {}
417 }
418 }
419
420 fn has_return_statement(&self, stmts: &[Statement]) -> bool {
422 for stmt in stmts {
423 match stmt {
424 Statement::Return(_, _) => return true,
425 Statement::If(if_stmt, _) => {
426 if let Some(else_body) = &if_stmt.else_body {
428 if self.has_return_statement(&if_stmt.then_body)
429 && self.has_return_statement(else_body)
430 {
431 return true;
432 }
433 }
434 }
435 Statement::While(while_loop, _) => {
436 if self.has_return_statement(&while_loop.body) {
437 return true;
439 }
440 }
441 Statement::For(for_loop, _) => {
442 if self.has_return_statement(&for_loop.body) {
443 return true;
445 }
446 }
447 _ => {}
448 }
449 }
450
451 false
452 }
453
454 fn is_cyclic_type_alias(&self, name: &str, ty: &TypeAnnotation) -> bool {
456 self.references_type(ty, name)
457 }
458
459 fn references_type(&self, ty: &TypeAnnotation, name: &str) -> bool {
461 match ty {
462 TypeAnnotation::Reference(ref_name) => ref_name == name,
463 TypeAnnotation::Array(elem) => self.references_type(elem, name),
464 TypeAnnotation::Tuple(elems) => {
465 elems.iter().any(|elem| self.references_type(elem, name))
466 }
467 TypeAnnotation::Object(fields) => fields
468 .iter()
469 .any(|field| self.references_type(&field.type_annotation, name)),
470 TypeAnnotation::Function { params, returns } => {
471 params
472 .iter()
473 .any(|param| self.references_type(¶m.type_annotation, name))
474 || self.references_type(returns, name)
475 }
476 TypeAnnotation::Union(types) => types.iter().any(|ty| self.references_type(ty, name)),
477 TypeAnnotation::Generic { args, .. } => {
478 args.iter().any(|arg| self.references_type(arg, name))
479 }
480 _ => false,
481 }
482 }
483
484 fn check_interface(&mut self, interface: &shape_ast::ast::InterfaceDef, interface_span: Span) {
486 let mut seen_members = HashMap::new();
488
489 for (i, member) in interface.members.iter().enumerate() {
490 let member_name = match member {
491 shape_ast::ast::InterfaceMember::Property { name, .. } => name,
492 shape_ast::ast::InterfaceMember::Method { name, .. } => name,
493 shape_ast::ast::InterfaceMember::IndexSignature { .. } => continue,
494 };
495
496 if let Some(_prev_index) = seen_members.get(member_name) {
497 let (line, col) = self.item_span_to_line_col(interface_span);
498 self.add_error(
499 TypeError::InterfaceError(
500 interface.name.clone(),
501 format!("Duplicate member '{}'", member_name),
502 ),
503 line,
504 col,
505 );
506 } else {
507 seen_members.insert(member_name.clone(), i);
508 }
509 }
510 }
511
512 fn item_span_to_line_col(&self, span: Span) -> (usize, usize) {
513 self.span_to_line_col(span).unwrap_or((0, 0))
514 }
515
516 fn add_error(&mut self, error: TypeError, line: usize, column: usize) {
518 let mut err = TypeErrorWithLocation::new(error, line, column);
519
520 if let Some(filename) = &self.filename {
521 err = err.with_file(filename.clone());
522 }
523
524 if let Some(source) = &self.source {
525 if let Some(source_line) = source.lines().nth(line.saturating_sub(1)) {
527 err = err.with_source_line(source_line.to_string());
528 }
529 }
530
531 self.errors.push(err);
532 }
533
534 pub fn errors(&self) -> &[TypeErrorWithLocation] {
536 &self.errors
537 }
538
539 pub fn format_errors(&self) -> String {
541 self.errors
542 .iter()
543 .map(|err| err.format_with_source())
544 .collect::<Vec<_>>()
545 .join("\n")
546 }
547}
548
549pub fn analyze_program(
551 program: &Program,
552 source: Option<&str>,
553 filename: Option<&str>,
554 known_bindings: Option<&[String]>,
555) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
556 analyze_program_with_mode(
557 program,
558 source,
559 filename,
560 known_bindings,
561 TypeAnalysisMode::FailFast,
562 )
563}
564
565pub fn analyze_program_with_mode(
567 program: &Program,
568 source: Option<&str>,
569 filename: Option<&str>,
570 known_bindings: Option<&[String]>,
571 analysis_mode: TypeAnalysisMode,
572) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
573 let mut checker = TypeChecker::new();
574 if let Some(src) = source {
575 checker = checker.with_source(src.to_string());
576 }
577 if let Some(file) = filename {
578 checker = checker.with_filename(file.to_string());
579 }
580 if let Some(names) = known_bindings {
581 checker = checker.with_known_bindings(names);
582 }
583 checker = checker.with_analysis_mode(analysis_mode);
584 checker.check_program(program)
585}
586
587#[derive(Debug)]
589pub struct TypeCheckResult {
590 pub types: HashMap<String, Type>,
592 pub semantic_types: HashMap<String, SemanticType>,
594 pub warnings: Vec<TypeWarning>,
596}
597
598impl TypeCheckResult {
599 pub fn get_semantic_type(&self, name: &str) -> Option<&SemanticType> {
601 self.semantic_types.get(name)
602 }
603
604 pub fn fallible_functions(&self) -> Vec<&str> {
606 self.semantic_types
607 .iter()
608 .filter_map(|(name, ty)| {
609 if let SemanticType::Function(sig) = ty {
610 if sig.return_type.is_result() {
611 return Some(name.as_str());
612 }
613 }
614 None
615 })
616 .collect()
617 }
618}
619
/// A non-fatal diagnostic with a 1-based source position.
#[derive(Debug)]
pub struct TypeWarning {
    /// Human-readable description of the warning.
    pub message: String,
    /// 1-based line number.
    pub line: usize,
    /// 1-based column number.
    pub column: usize,
}
627
628pub fn type_of_expr(expr: &Expr, _env: &TypeEnvironment) -> TypeResult<Type> {
630 let mut engine = TypeInferenceEngine::new();
631 engine.infer_expr(expr)
632}
633
634pub fn quick_check(source: &str) -> Result<TypeCheckResult, String> {
636 use shape_ast::parser::parse_program;
637
638 let program = parse_program(source).map_err(|e| format!("Parse error: {}", e))?;
639
640 let mut checker = TypeChecker::new().with_source(source.to_string());
641
642 checker.check_program(&program).map_err(|errors| {
643 errors
644 .iter()
645 .map(|e| e.format_with_source())
646 .collect::<Vec<_>>()
647 .join("\n")
648 })
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    // A `match` over an enum that omits variants (`Inactive`, `Pending`)
    // must fail, and the rendered error must mention the exhaustiveness
    // problem.
    #[test]
    fn test_exhaustiveness_integration_non_exhaustive_match_produces_error() {
        let source = r#"
        enum Status { Active, Inactive, Pending }

        function check(s: Status) {
            return match s {
                Status::Active => "yes"
            };
        }
        "#;

        let result = quick_check(source);

        assert!(
            result.is_err(),
            "Expected error for non-exhaustive match, got: {:?}",
            result
        );
        let err = result.unwrap_err();
        // Accept any of several phrasings the rendered error might use.
        assert!(
            err.contains("NonExhaustive")
                || err.contains("non-exhaustive")
                || err.contains("missing"),
            "Expected non-exhaustive match error, got: {}",
            err
        );
    }

    // Covering every variant must not produce an exhaustiveness error.
    #[test]
    fn test_exhaustiveness_integration_exhaustive_match_succeeds() {
        let source = r#"
        enum Status { Active, Inactive }

        function check(s: Status) {
            return match s {
                Status::Active => "yes",
                Status::Inactive => "no"
            };
        }
        "#;

        let result = quick_check(source);

        // Other, unrelated errors are tolerated; only an exhaustiveness
        // complaint fails this test.
        if let Err(err) = &result {
            assert!(
                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
                "Should not have non-exhaustive error for exhaustive match, got: {}",
                err
            );
        }
    }

    // A `_` arm covers all remaining variants.
    #[test]
    fn test_exhaustiveness_integration_wildcard_makes_exhaustive() {
        let source = r#"
        enum Status { Active, Inactive, Pending }

        function check(s: Status) {
            return match s {
                Status::Active => "yes",
                _ => "other"
            };
        }
        "#;

        let result = quick_check(source);

        // As above, only an exhaustiveness complaint fails this test.
        if let Err(err) = &result {
            assert!(
                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
                "Wildcard should make match exhaustive, got: {}",
                err
            );
        }
    }

    // The undefined identifier `duckdb` must be reported at its own position:
    // line 3 (the literal starts with a newline), column 9 (after `let y = `).
    // The source literal below must stay byte-identical for these numbers.
    #[test]
    fn test_undefined_variable_reports_identifier_position() {
        use shape_ast::parser::parse_program;

        let source = r#"
let x = 1
let y = duckdb.connect("duckdb://analytics.db")
"#;

        let program = parse_program(source).expect("program should parse");
        let result = analyze_program(&program, Some(source), None, None);
        let errors = result.expect_err("undefined variable should fail analysis");
        let undef = errors
            .iter()
            .find(|e| matches!(&e.error, TypeError::UndefinedVariable(name) if name == "duckdb"))
            .expect("missing undefined-variable error for duckdb");

        assert_eq!(undef.line, 3);
        assert_eq!(undef.column, 9);
    }

    // Passing `duckdb` as a known binding must suppress the undefined-variable
    // error the previous test relies on.
    #[test]
    fn test_known_bindings_allow_extension_namespace_in_type_analysis() {
        use shape_ast::parser::parse_program;

        let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
        let program = parse_program(source).expect("program should parse");
        let known = vec!["duckdb".to_string()];

        let result = analyze_program(&program, Some(source), None, Some(&known));
        assert!(
            result.is_ok(),
            "known extension namespaces should not fail type analysis: {:?}",
            result.err()
        );
    }
}