1use super::errors::{TypeError, TypeResult};
12use super::semantic::SemanticType;
13use super::types::Type;
14
15#[cfg(test)]
17use super::semantic::EnumVariant;
18use super::types::annotation_to_semantic;
19use shape_ast::ast::TypeAnnotation;
20use shape_ast::ast::{MatchArm, MatchExpr, Pattern};
21use std::collections::HashSet;
22
/// Outcome of checking whether a `match` expression covers every possible value of
/// its scrutinee.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExhaustivenessResult {
    /// Every enum variant (or union member) is covered by an unguarded arm.
    Exhaustive,
    /// At least one variant/member is not covered by any unguarded arm.
    NonExhaustive {
        /// Name of the scrutinee's enum, or the rendered union type (e.g. `int | string`).
        enum_name: String,
        /// Names of the variants/members that no unguarded arm covers.
        missing_variants: Vec<String>,
    },
    /// An unguarded wildcard or identifier arm matches any value.
    TriviallyExhaustive,
    /// The scrutinee is not an enum or union (or its type could not be resolved) and
    /// there is no unguarded catch-all arm, so no check was performed.
    NotApplicable,
}
38
39impl ExhaustivenessResult {
40 pub fn is_exhaustive(&self) -> bool {
42 matches!(
43 self,
44 ExhaustivenessResult::Exhaustive
45 | ExhaustivenessResult::TriviallyExhaustive
46 | ExhaustivenessResult::NotApplicable
47 )
48 }
49
50 pub fn to_error(&self) -> Option<TypeError> {
52 match self {
53 ExhaustivenessResult::NonExhaustive {
54 enum_name,
55 missing_variants,
56 } => Some(TypeError::NonExhaustiveMatch {
57 enum_name: enum_name.clone(),
58 missing_variants: missing_variants.clone(),
59 }),
60 _ => None,
61 }
62 }
63}
64
65pub fn check_exhaustiveness(
67 match_expr: &MatchExpr,
68 scrutinee_type: &SemanticType,
69) -> ExhaustivenessResult {
70 let (enum_name, variants) = match scrutinee_type {
73 SemanticType::Enum { name, variants, .. } => (name.clone(), variants.clone()),
74 _ => {
76 if has_unguarded_catch_all(&match_expr.arms) {
77 return ExhaustivenessResult::TriviallyExhaustive;
78 }
79 return ExhaustivenessResult::NotApplicable;
80 }
81 };
82
83 let covered = collect_covered_variants(&match_expr.arms, &enum_name);
85
86 if has_unguarded_catch_all(&match_expr.arms) {
88 return ExhaustivenessResult::TriviallyExhaustive;
89 }
90
91 let all_variants: HashSet<_> = variants.iter().map(|v| v.name.clone()).collect();
93 let missing: Vec<_> = all_variants.difference(&covered).cloned().collect();
94
95 if missing.is_empty() {
96 ExhaustivenessResult::Exhaustive
97 } else {
98 ExhaustivenessResult::NonExhaustive {
99 enum_name,
100 missing_variants: missing,
101 }
102 }
103}
104
105pub fn check_exhaustiveness_for_type(
109 match_expr: &MatchExpr,
110 scrutinee_type: &Type,
111) -> ExhaustivenessResult {
112 if let Some(TypeAnnotation::Union(variants)) = scrutinee_type.to_annotation() {
113 return check_union_exhaustiveness(match_expr, &variants);
114 }
115
116 if let Some(semantic_type) = scrutinee_type.to_semantic() {
117 return check_exhaustiveness(match_expr, &semantic_type);
118 }
119
120 if has_unguarded_catch_all(&match_expr.arms) {
121 ExhaustivenessResult::TriviallyExhaustive
122 } else {
123 tracing::debug!(
126 "exhaustiveness check skipped: scrutinee type {:?} could not be resolved",
127 scrutinee_type
128 );
129 ExhaustivenessResult::NotApplicable
130 }
131}
132
133fn check_union_exhaustiveness(
134 match_expr: &MatchExpr,
135 union_variants: &[TypeAnnotation],
136) -> ExhaustivenessResult {
137 if has_unguarded_catch_all(&match_expr.arms) {
138 return ExhaustivenessResult::TriviallyExhaustive;
139 }
140
141 let covered_types = collect_covered_union_types(&match_expr.arms);
142 let missing: Vec<TypeAnnotation> = union_variants
143 .iter()
144 .filter(|variant| {
145 !covered_types
146 .iter()
147 .any(|covered| types_match(covered, variant))
148 })
149 .cloned()
150 .collect();
151
152 if missing.is_empty() {
153 ExhaustivenessResult::Exhaustive
154 } else {
155 ExhaustivenessResult::NonExhaustive {
156 enum_name: format_union_type_name(union_variants),
157 missing_variants: missing.iter().map(format_type_annotation).collect(),
158 }
159 }
160}
161
162fn collect_covered_union_types(arms: &[MatchArm]) -> Vec<TypeAnnotation> {
163 let mut covered = Vec::new();
164
165 for arm in arms {
166 if arm.guard.is_some() {
168 continue;
169 }
170
171 if let Pattern::Typed {
172 type_annotation, ..
173 } = &arm.pattern
174 {
175 for ty in flatten_union_annotation(type_annotation) {
176 if !covered.iter().any(|existing| types_match(existing, ty)) {
177 covered.push(ty.clone());
178 }
179 }
180 }
181 }
182
183 covered
184}
185
186fn flatten_union_annotation(ann: &TypeAnnotation) -> Vec<&TypeAnnotation> {
187 match ann {
188 TypeAnnotation::Union(types) => {
189 let mut out = Vec::new();
190 for ty in types {
191 out.extend(flatten_union_annotation(ty));
192 }
193 out
194 }
195 _ => vec![ann],
196 }
197}
198
199fn types_match(a: &TypeAnnotation, b: &TypeAnnotation) -> bool {
200 annotation_to_semantic(a) == annotation_to_semantic(b)
201}
202
203fn format_union_type_name(types: &[TypeAnnotation]) -> String {
204 types
205 .iter()
206 .map(format_type_annotation)
207 .collect::<Vec<_>>()
208 .join(" | ")
209}
210
211fn format_type_annotation(ann: &TypeAnnotation) -> String {
212 match ann {
213 TypeAnnotation::Basic(name) => name.clone(),
214 TypeAnnotation::Reference(name) => name.to_string(),
215 TypeAnnotation::Array(inner) => format!("Vec<{}>", format_type_annotation(inner)),
216 TypeAnnotation::Tuple(elems) => format!(
217 "[{}]",
218 elems
219 .iter()
220 .map(format_type_annotation)
221 .collect::<Vec<_>>()
222 .join(", ")
223 ),
224 TypeAnnotation::Object(_) => "object".to_string(),
225 TypeAnnotation::Function { .. } => "function".to_string(),
226 TypeAnnotation::Union(types) => types
227 .iter()
228 .map(format_type_annotation)
229 .collect::<Vec<_>>()
230 .join(" | "),
231 TypeAnnotation::Intersection(types) => types
232 .iter()
233 .map(format_type_annotation)
234 .collect::<Vec<_>>()
235 .join(" + "),
236 TypeAnnotation::Generic { name, args } => {
237 if args.is_empty() {
238 name.to_string()
239 } else {
240 format!(
241 "{}<{}>",
242 name,
243 args.iter()
244 .map(format_type_annotation)
245 .collect::<Vec<_>>()
246 .join(", ")
247 )
248 }
249 }
250 TypeAnnotation::Void => "void".to_string(),
251 TypeAnnotation::Never => "never".to_string(),
252 TypeAnnotation::Null => "None".to_string(),
253 TypeAnnotation::Undefined => "undefined".to_string(),
254 TypeAnnotation::Dyn(traits) => format!("dyn {}", traits.join(" + ")),
255 }
256}
257
258fn has_unguarded_catch_all(arms: &[MatchArm]) -> bool {
260 arms.iter().any(|arm| {
261 if arm.guard.is_some() {
263 return false;
264 }
265 is_catch_all_pattern(&arm.pattern)
266 })
267}
268
269fn is_catch_all_pattern(pattern: &Pattern) -> bool {
271 match pattern {
272 Pattern::Wildcard => true,
274 Pattern::Identifier(_) => true,
276 _ => false,
278 }
279}
280
281fn collect_covered_variants(arms: &[MatchArm], enum_name: &str) -> HashSet<String> {
283 let mut covered = HashSet::new();
284
285 for arm in arms {
286 if arm.guard.is_some() {
288 continue;
289 }
290
291 if let Some(variant_name) = extract_variant_name(&arm.pattern, enum_name) {
292 covered.insert(variant_name);
293 }
294 }
295
296 covered
297}
298
299fn extract_variant_name(pattern: &Pattern, expected_enum: &str) -> Option<String> {
301 match pattern {
302 Pattern::Constructor {
303 enum_name, variant, ..
304 } => {
305 match enum_name {
307 Some(name) if name == expected_enum => Some(variant.clone()),
308 None => Some(variant.clone()), _ => None,
310 }
311 }
312 _ => None,
313 }
314}
315
316pub fn require_exhaustive(match_expr: &MatchExpr, scrutinee_type: &SemanticType) -> TypeResult<()> {
318 let result = check_exhaustiveness(match_expr, scrutinee_type);
319 match result.to_error() {
320 Some(err) => Err(err),
321 None => Ok(()),
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::{Expr, Literal, Span};

    fn make_span() -> Span {
        Span { start: 0, end: 0 }
    }

    /// Builds an enum type whose variants carry no payload.
    fn make_enum_type(name: &str, variants: &[&str]) -> SemanticType {
        SemanticType::Enum {
            name: name.to_string(),
            variants: variants
                .iter()
                .map(|v| EnumVariant {
                    name: v.to_string(),
                    payload: None,
                })
                .collect(),
            type_params: vec![],
        }
    }

    fn make_match_arm(pattern: Pattern, guard: Option<Expr>, body: Expr) -> MatchArm {
        MatchArm {
            pattern,
            guard: guard.map(Box::new),
            body: Box::new(body),
            pattern_span: None,
        }
    }

    /// Builds `match <scrutinee> { <arms> }` with an identifier scrutinee.
    fn make_match(scrutinee: &str, arms: Vec<MatchArm>) -> MatchExpr {
        MatchExpr {
            scrutinee: Box::new(Expr::Identifier(scrutinee.to_string(), make_span())),
            arms,
        }
    }

    fn make_constructor_pattern(enum_name: Option<&str>, variant: &str) -> Pattern {
        Pattern::Constructor {
            enum_name: enum_name.map(|s| s.into()),
            variant: variant.to_string(),
            fields: shape_ast::ast::PatternConstructorFields::Unit,
        }
    }

    fn make_typed_pattern(name: &str, type_annotation: TypeAnnotation) -> Pattern {
        Pattern::Typed {
            name: name.to_string(),
            type_annotation,
        }
    }

    fn make_string_expr(s: &str) -> Expr {
        Expr::Literal(Literal::String(s.to_string()), make_span())
    }

    /// A guard expression. Its value is irrelevant, because any guard disqualifies
    /// an arm from coverage.
    fn make_guard() -> Option<Expr> {
        Some(Expr::Literal(Literal::Bool(true), make_span()))
    }

    fn basic(name: &str) -> TypeAnnotation {
        TypeAnnotation::Basic(name.to_string())
    }

    /// The scrutinee type `int | string`.
    fn make_int_or_string() -> Type {
        Type::Concrete(TypeAnnotation::Union(vec![basic("int"), basic("string")]))
    }

    /// An unguarded `Status::<variant>` arm.
    fn status_arm(variant: &str) -> MatchArm {
        make_match_arm(
            make_constructor_pattern(Some("Status"), variant),
            None,
            make_string_expr(variant),
        )
    }

    #[test]
    fn test_exhaustive_match_all_variants() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match("status", vec![status_arm("Active"), status_arm("Inactive")]);

        let result = check_exhaustiveness(&match_expr, &status_type);
        assert_eq!(result, ExhaustivenessResult::Exhaustive);
    }

    #[test]
    fn test_non_exhaustive_missing_variant() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match("status", vec![status_arm("Active")]);

        match check_exhaustiveness(&match_expr, &status_type) {
            ExhaustivenessResult::NonExhaustive {
                enum_name,
                missing_variants,
            } => {
                assert_eq!(enum_name, "Status");
                assert_eq!(missing_variants, vec!["Inactive"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_non_exhaustive_reports_every_missing_variant() {
        let status_type = make_enum_type("Status", &["Active", "Inactive", "Pending"]);
        let match_expr = make_match("status", vec![status_arm("Active")]);

        match check_exhaustiveness(&match_expr, &status_type) {
            ExhaustivenessResult::NonExhaustive {
                mut missing_variants,
                ..
            } => {
                // Sort so the assertion does not depend on reporting order.
                missing_variants.sort();
                assert_eq!(missing_variants, vec!["Inactive", "Pending"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_exhaustive_with_wildcard() {
        let status_type = make_enum_type("Status", &["Active", "Inactive", "Pending"]);
        let match_expr = make_match(
            "status",
            vec![
                status_arm("Active"),
                make_match_arm(Pattern::Wildcard, None, make_string_expr("no")),
            ],
        );

        let result = check_exhaustiveness(&match_expr, &status_type);
        assert_eq!(result, ExhaustivenessResult::TriviallyExhaustive);
    }

    #[test]
    fn test_guarded_pattern_does_not_count() {
        // The guard on `Active` may fail at runtime, so that arm does not cover `Active`.
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match(
            "status",
            vec![
                make_match_arm(
                    make_constructor_pattern(Some("Status"), "Active"),
                    make_guard(),
                    make_string_expr("yes"),
                ),
                status_arm("Inactive"),
            ],
        );

        match check_exhaustiveness(&match_expr, &status_type) {
            ExhaustivenessResult::NonExhaustive {
                missing_variants, ..
            } => {
                assert_eq!(missing_variants, vec!["Active"]);
            }
            other => panic!(
                "Expected NonExhaustive because guarded Active doesn't count, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_guarded_wildcard_is_not_a_catch_all() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match(
            "status",
            vec![
                status_arm("Active"),
                make_match_arm(Pattern::Wildcard, make_guard(), make_string_expr("other")),
            ],
        );

        match check_exhaustiveness(&match_expr, &status_type) {
            ExhaustivenessResult::NonExhaustive {
                missing_variants, ..
            } => {
                assert_eq!(missing_variants, vec!["Inactive"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_unqualified_constructor_patterns_count() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match(
            "status",
            vec![
                make_match_arm(
                    make_constructor_pattern(None, "Active"),
                    None,
                    make_string_expr("yes"),
                ),
                make_match_arm(
                    make_constructor_pattern(None, "Inactive"),
                    None,
                    make_string_expr("no"),
                ),
            ],
        );

        let result = check_exhaustiveness(&match_expr, &status_type);
        assert_eq!(result, ExhaustivenessResult::Exhaustive);
    }

    #[test]
    fn test_constructor_for_other_enum_does_not_count() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);
        let match_expr = make_match(
            "status",
            vec![
                make_match_arm(
                    make_constructor_pattern(Some("Other"), "Active"),
                    None,
                    make_string_expr("yes"),
                ),
                status_arm("Inactive"),
            ],
        );

        match check_exhaustiveness(&match_expr, &status_type) {
            ExhaustivenessResult::NonExhaustive {
                missing_variants, ..
            } => {
                assert_eq!(missing_variants, vec!["Active"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_non_enum_with_wildcard_is_exhaustive() {
        let match_expr = make_match(
            "x",
            vec![
                make_match_arm(
                    Pattern::Literal(Literal::Number(1.0)),
                    None,
                    make_string_expr("one"),
                ),
                make_match_arm(Pattern::Wildcard, None, make_string_expr("other")),
            ],
        );

        let result = check_exhaustiveness(&match_expr, &SemanticType::Number);
        assert_eq!(result, ExhaustivenessResult::TriviallyExhaustive);
    }

    #[test]
    fn test_non_enum_without_catch_all_is_not_applicable() {
        let match_expr = make_match(
            "x",
            vec![make_match_arm(
                Pattern::Literal(Literal::Number(1.0)),
                None,
                make_string_expr("one"),
            )],
        );

        let result = check_exhaustiveness(&match_expr, &SemanticType::Number);
        assert_eq!(result, ExhaustivenessResult::NotApplicable);
    }

    #[test]
    fn test_result_helpers() {
        assert!(ExhaustivenessResult::Exhaustive.is_exhaustive());
        assert!(ExhaustivenessResult::TriviallyExhaustive.is_exhaustive());
        assert!(ExhaustivenessResult::NotApplicable.is_exhaustive());
        assert!(ExhaustivenessResult::Exhaustive.to_error().is_none());
        assert!(ExhaustivenessResult::NotApplicable.to_error().is_none());

        let non_exhaustive = ExhaustivenessResult::NonExhaustive {
            enum_name: "Status".to_string(),
            missing_variants: vec!["Inactive".to_string()],
        };
        assert!(!non_exhaustive.is_exhaustive());
        match non_exhaustive.to_error() {
            Some(TypeError::NonExhaustiveMatch {
                enum_name,
                missing_variants,
            }) => {
                assert_eq!(enum_name, "Status");
                assert_eq!(missing_variants, vec!["Inactive"]);
            }
            _ => panic!("Expected a NonExhaustiveMatch error"),
        }
    }

    #[test]
    fn test_require_exhaustive() {
        let status_type = make_enum_type("Status", &["Active", "Inactive"]);

        let complete = make_match("status", vec![status_arm("Active"), status_arm("Inactive")]);
        assert!(require_exhaustive(&complete, &status_type).is_ok());

        let partial = make_match("status", vec![status_arm("Active")]);
        match require_exhaustive(&partial, &status_type) {
            Err(TypeError::NonExhaustiveMatch {
                enum_name,
                missing_variants,
            }) => {
                assert_eq!(enum_name, "Status");
                assert_eq!(missing_variants, vec!["Inactive"]);
            }
            _ => panic!("Expected Err(NonExhaustiveMatch)"),
        }
    }

    #[test]
    fn test_union_typed_patterns_are_exhaustive() {
        let match_expr = make_match(
            "x",
            vec![
                make_match_arm(
                    make_typed_pattern("n", basic("int")),
                    None,
                    make_string_expr("int"),
                ),
                make_match_arm(
                    make_typed_pattern("s", basic("string")),
                    None,
                    make_string_expr("string"),
                ),
            ],
        );

        let result = check_exhaustiveness_for_type(&match_expr, &make_int_or_string());
        assert_eq!(result, ExhaustivenessResult::Exhaustive);
    }

    #[test]
    fn test_union_typed_patterns_missing_variant_reports_non_exhaustive() {
        let match_expr = make_match(
            "x",
            vec![make_match_arm(
                make_typed_pattern("n", basic("int")),
                None,
                make_string_expr("int"),
            )],
        );

        match check_exhaustiveness_for_type(&match_expr, &make_int_or_string()) {
            ExhaustivenessResult::NonExhaustive {
                enum_name,
                missing_variants,
            } => {
                assert_eq!(enum_name, "int | string");
                assert_eq!(missing_variants, vec!["string"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_union_annotation_in_pattern_covers_each_member() {
        let match_expr = make_match(
            "x",
            vec![make_match_arm(
                make_typed_pattern(
                    "v",
                    TypeAnnotation::Union(vec![basic("int"), basic("string")]),
                ),
                None,
                make_string_expr("either"),
            )],
        );

        let result = check_exhaustiveness_for_type(&match_expr, &make_int_or_string());
        assert_eq!(result, ExhaustivenessResult::Exhaustive);
    }

    #[test]
    fn test_guarded_typed_pattern_does_not_count() {
        let match_expr = make_match(
            "x",
            vec![
                make_match_arm(
                    make_typed_pattern("n", basic("int")),
                    make_guard(),
                    make_string_expr("int"),
                ),
                make_match_arm(
                    make_typed_pattern("s", basic("string")),
                    None,
                    make_string_expr("string"),
                ),
            ],
        );

        match check_exhaustiveness_for_type(&match_expr, &make_int_or_string()) {
            ExhaustivenessResult::NonExhaustive {
                enum_name,
                missing_variants,
            } => {
                assert_eq!(enum_name, "int | string");
                assert_eq!(missing_variants, vec!["int"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }

    #[test]
    fn test_union_with_wildcard_is_trivially_exhaustive() {
        let match_expr = make_match(
            "x",
            vec![
                make_match_arm(
                    make_typed_pattern("n", basic("int")),
                    None,
                    make_string_expr("int"),
                ),
                make_match_arm(Pattern::Wildcard, None, make_string_expr("other")),
            ],
        );

        let result = check_exhaustiveness_for_type(&match_expr, &make_int_or_string());
        assert_eq!(result, ExhaustivenessResult::TriviallyExhaustive);
    }
}