Skip to main content

shape_runtime/type_system/
exhaustiveness.rs

//! Exhaustiveness Checking for Match Expressions
//!
//! Implements compile-time verification that a match expression covers every variant of an
//! enum scrutinee, or every member type of a closed union scrutinee.
//!
//! Rules:
//! 1. Arms with `where` guards do NOT contribute to exhaustiveness coverage
//! 2. An unguarded `_` (wildcard) or bare identifier pattern makes the match exhaustive
//! 3. For enums: Uncovered = AllVariants - CoveredVariants (from constructor patterns)
//! 4. For closed unions: Uncovered = AllMemberTypes - CoveredTypes (from typed patterns)
10
11use super::errors::{TypeError, TypeResult};
12use super::semantic::SemanticType;
13use super::types::Type;
14
15// EnumVariant is used in tests
16#[cfg(test)]
17use super::semantic::EnumVariant;
18use super::types::annotation_to_semantic;
19use shape_ast::ast::TypeAnnotation;
20use shape_ast::ast::{MatchArm, MatchExpr, Pattern};
21use std::collections::HashSet;
22
/// Outcome of checking a `match` expression for exhaustiveness.
///
/// Only [`ExhaustivenessResult::NonExhaustive`] describes a problem; every other variant
/// means no diagnostic is required.
#[derive(Debug, Clone, PartialEq)]
pub enum ExhaustivenessResult {
    /// Every variant (or union member) of the scrutinee type has an unguarded arm.
    Exhaustive,
    /// At least one variant (or union member) has no unguarded arm.
    NonExhaustive {
        /// Name of the enum, or the rendered union type (for example `int | string`).
        enum_name: String,
        /// Names of the uncovered variants or union members.
        missing_variants: Vec<String>,
    },
    /// An unguarded catch-all arm (`_` or a bare identifier) matches every value.
    TriviallyExhaustive,
    /// The scrutinee type is not one this module analyses, so no check was performed.
    NotApplicable,
}
38
39impl ExhaustivenessResult {
40    /// Returns true if the match is exhaustive
41    pub fn is_exhaustive(&self) -> bool {
42        matches!(
43            self,
44            ExhaustivenessResult::Exhaustive
45                | ExhaustivenessResult::TriviallyExhaustive
46                | ExhaustivenessResult::NotApplicable
47        )
48    }
49
50    /// Convert to a TypeError if non-exhaustive
51    pub fn to_error(&self) -> Option<TypeError> {
52        match self {
53            ExhaustivenessResult::NonExhaustive {
54                enum_name,
55                missing_variants,
56            } => Some(TypeError::NonExhaustiveMatch {
57                enum_name: enum_name.clone(),
58                missing_variants: missing_variants.clone(),
59            }),
60            _ => None,
61        }
62    }
63}
64
65/// Check exhaustiveness of a match expression
66pub fn check_exhaustiveness(
67    match_expr: &MatchExpr,
68    scrutinee_type: &SemanticType,
69) -> ExhaustivenessResult {
70    // Only check enums for now - other types are either trivially exhaustive
71    // or require more sophisticated pattern analysis
72    let (enum_name, variants) = match scrutinee_type {
73        SemanticType::Enum { name, variants, .. } => (name.clone(), variants.clone()),
74        // For non-enum types, check if there's a wildcard pattern
75        _ => {
76            if has_unguarded_catch_all(&match_expr.arms) {
77                return ExhaustivenessResult::TriviallyExhaustive;
78            }
79            return ExhaustivenessResult::NotApplicable;
80        }
81    };
82
83    // Collect all covered variants from unguarded patterns
84    let covered = collect_covered_variants(&match_expr.arms, &enum_name);
85
86    // Check for trivial exhaustiveness (wildcard or catch-all)
87    if has_unguarded_catch_all(&match_expr.arms) {
88        return ExhaustivenessResult::TriviallyExhaustive;
89    }
90
91    // Compute missing variants
92    let all_variants: HashSet<_> = variants.iter().map(|v| v.name.clone()).collect();
93    let missing: Vec<_> = all_variants.difference(&covered).cloned().collect();
94
95    if missing.is_empty() {
96        ExhaustivenessResult::Exhaustive
97    } else {
98        ExhaustivenessResult::NonExhaustive {
99            enum_name,
100            missing_variants: missing,
101        }
102    }
103}
104
105/// Check exhaustiveness from inference-level type information.
106///
107/// This supports enum exhaustiveness and closed union exhaustiveness.
108pub fn check_exhaustiveness_for_type(
109    match_expr: &MatchExpr,
110    scrutinee_type: &Type,
111) -> ExhaustivenessResult {
112    if let Some(TypeAnnotation::Union(variants)) = scrutinee_type.to_annotation() {
113        return check_union_exhaustiveness(match_expr, &variants);
114    }
115
116    if let Some(semantic_type) = scrutinee_type.to_semantic() {
117        return check_exhaustiveness(match_expr, &semantic_type);
118    }
119
120    if has_unguarded_catch_all(&match_expr.arms) {
121        ExhaustivenessResult::TriviallyExhaustive
122    } else {
123        // Type inference could not resolve the scrutinee type, so exhaustiveness
124        // checking is skipped. This can mask missing match arms at compile time.
125        tracing::debug!(
126            "exhaustiveness check skipped: scrutinee type {:?} could not be resolved",
127            scrutinee_type
128        );
129        ExhaustivenessResult::NotApplicable
130    }
131}
132
133fn check_union_exhaustiveness(
134    match_expr: &MatchExpr,
135    union_variants: &[TypeAnnotation],
136) -> ExhaustivenessResult {
137    if has_unguarded_catch_all(&match_expr.arms) {
138        return ExhaustivenessResult::TriviallyExhaustive;
139    }
140
141    let covered_types = collect_covered_union_types(&match_expr.arms);
142    let missing: Vec<TypeAnnotation> = union_variants
143        .iter()
144        .filter(|variant| {
145            !covered_types
146                .iter()
147                .any(|covered| types_match(covered, variant))
148        })
149        .cloned()
150        .collect();
151
152    if missing.is_empty() {
153        ExhaustivenessResult::Exhaustive
154    } else {
155        ExhaustivenessResult::NonExhaustive {
156            enum_name: format_union_type_name(union_variants),
157            missing_variants: missing.iter().map(format_type_annotation).collect(),
158        }
159    }
160}
161
162fn collect_covered_union_types(arms: &[MatchArm]) -> Vec<TypeAnnotation> {
163    let mut covered = Vec::new();
164
165    for arm in arms {
166        // Guarded arms do not contribute to exhaustiveness
167        if arm.guard.is_some() {
168            continue;
169        }
170
171        if let Pattern::Typed {
172            type_annotation, ..
173        } = &arm.pattern
174        {
175            for ty in flatten_union_annotation(type_annotation) {
176                if !covered.iter().any(|existing| types_match(existing, ty)) {
177                    covered.push(ty.clone());
178                }
179            }
180        }
181    }
182
183    covered
184}
185
186fn flatten_union_annotation(ann: &TypeAnnotation) -> Vec<&TypeAnnotation> {
187    match ann {
188        TypeAnnotation::Union(types) => {
189            let mut out = Vec::new();
190            for ty in types {
191                out.extend(flatten_union_annotation(ty));
192            }
193            out
194        }
195        _ => vec![ann],
196    }
197}
198
199fn types_match(a: &TypeAnnotation, b: &TypeAnnotation) -> bool {
200    annotation_to_semantic(a) == annotation_to_semantic(b)
201}
202
203fn format_union_type_name(types: &[TypeAnnotation]) -> String {
204    types
205        .iter()
206        .map(format_type_annotation)
207        .collect::<Vec<_>>()
208        .join(" | ")
209}
210
211fn format_type_annotation(ann: &TypeAnnotation) -> String {
212    match ann {
213        TypeAnnotation::Basic(name) => name.clone(),
214        TypeAnnotation::Reference(name) => name.to_string(),
215        TypeAnnotation::Array(inner) => format!("Vec<{}>", format_type_annotation(inner)),
216        TypeAnnotation::Tuple(elems) => format!(
217            "[{}]",
218            elems
219                .iter()
220                .map(format_type_annotation)
221                .collect::<Vec<_>>()
222                .join(", ")
223        ),
224        TypeAnnotation::Object(_) => "object".to_string(),
225        TypeAnnotation::Function { .. } => "function".to_string(),
226        TypeAnnotation::Union(types) => types
227            .iter()
228            .map(format_type_annotation)
229            .collect::<Vec<_>>()
230            .join(" | "),
231        TypeAnnotation::Intersection(types) => types
232            .iter()
233            .map(format_type_annotation)
234            .collect::<Vec<_>>()
235            .join(" + "),
236        TypeAnnotation::Generic { name, args } => {
237            if args.is_empty() {
238                name.to_string()
239            } else {
240                format!(
241                    "{}<{}>",
242                    name,
243                    args.iter()
244                        .map(format_type_annotation)
245                        .collect::<Vec<_>>()
246                        .join(", ")
247                )
248            }
249        }
250        TypeAnnotation::Void => "void".to_string(),
251        TypeAnnotation::Never => "never".to_string(),
252        TypeAnnotation::Null => "None".to_string(),
253        TypeAnnotation::Undefined => "undefined".to_string(),
254        TypeAnnotation::Dyn(traits) => format!("dyn {}", traits.join(" + ")),
255    }
256}
257
258/// Check if the match has an unguarded catch-all pattern
259fn has_unguarded_catch_all(arms: &[MatchArm]) -> bool {
260    arms.iter().any(|arm| {
261        // Only unguarded patterns count
262        if arm.guard.is_some() {
263            return false;
264        }
265        is_catch_all_pattern(&arm.pattern)
266    })
267}
268
269/// Check if a pattern is a catch-all (matches everything)
270fn is_catch_all_pattern(pattern: &Pattern) -> bool {
271    match pattern {
272        // Wildcard matches everything
273        Pattern::Wildcard => true,
274        // Identifier without guard matches everything
275        Pattern::Identifier(_) => true,
276        // Other patterns are not catch-all
277        _ => false,
278    }
279}
280
281/// Collect all variant names covered by unguarded constructor patterns
282fn collect_covered_variants(arms: &[MatchArm], enum_name: &str) -> HashSet<String> {
283    let mut covered = HashSet::new();
284
285    for arm in arms {
286        // Patterns with guards do NOT contribute to exhaustiveness
287        if arm.guard.is_some() {
288            continue;
289        }
290
291        if let Some(variant_name) = extract_variant_name(&arm.pattern, enum_name) {
292            covered.insert(variant_name);
293        }
294    }
295
296    covered
297}
298
299/// Extract the variant name from a constructor pattern
300fn extract_variant_name(pattern: &Pattern, expected_enum: &str) -> Option<String> {
301    match pattern {
302        Pattern::Constructor {
303            enum_name, variant, ..
304        } => {
305            // Check if this matches the expected enum
306            match enum_name {
307                Some(name) if name == expected_enum => Some(variant.clone()),
308                None => Some(variant.clone()), // Allow unqualified variant names
309                _ => None,
310            }
311        }
312        _ => None,
313    }
314}
315
316/// Check a match expression and return an error if non-exhaustive
317pub fn require_exhaustive(match_expr: &MatchExpr, scrutinee_type: &SemanticType) -> TypeResult<()> {
318    let result = check_exhaustiveness(match_expr, scrutinee_type);
319    match result.to_error() {
320        Some(err) => Err(err),
321        None => Ok(()),
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use shape_ast::ast::{Expr, Literal, Span};

    /// Zero-width span; source positions are irrelevant to exhaustiveness checking.
    fn dummy_span() -> Span {
        Span { start: 0, end: 0 }
    }

    /// Build an enum type whose variants carry no payload.
    fn enum_type(name: &str, variant_names: &[&str]) -> SemanticType {
        let variants = variant_names
            .iter()
            .map(|variant_name| EnumVariant {
                name: variant_name.to_string(),
                payload: None,
            })
            .collect();
        SemanticType::Enum {
            name: name.to_string(),
            variants,
            type_params: vec![],
        }
    }

    fn string_lit(text: &str) -> Expr {
        Expr::Literal(Literal::String(text.to_string()), dummy_span())
    }

    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string(), dummy_span())
    }

    /// Build a match arm; pass `Some(guard)` for a guarded arm.
    fn arm(pattern: Pattern, guard: Option<Expr>, body: Expr) -> MatchArm {
        MatchArm {
            pattern,
            guard: guard.map(Box::new),
            body: Box::new(body),
            pattern_span: None,
        }
    }

    /// Build an unguarded arm whose body is a string literal.
    fn plain_arm(pattern: Pattern, body: &str) -> MatchArm {
        arm(pattern, None, string_lit(body))
    }

    fn unit_constructor(enum_name: Option<&str>, variant: &str) -> Pattern {
        Pattern::Constructor {
            enum_name: enum_name.map(|s| s.into()),
            variant: variant.to_string(),
            fields: shape_ast::ast::PatternConstructorFields::Unit,
        }
    }

    /// Qualified `Status::<variant>` constructor pattern.
    fn status_variant(variant: &str) -> Pattern {
        unit_constructor(Some("Status"), variant)
    }

    /// Typed binding pattern `binding: type_name`.
    fn typed(binding: &str, type_name: &str) -> Pattern {
        Pattern::Typed {
            name: binding.to_string(),
            type_annotation: TypeAnnotation::Basic(type_name.to_string()),
        }
    }

    fn match_on(scrutinee: &str, arms: Vec<MatchArm>) -> MatchExpr {
        MatchExpr {
            scrutinee: Box::new(ident(scrutinee)),
            arms,
        }
    }

    /// The closed union `int | string`.
    fn int_or_string() -> Type {
        Type::Concrete(TypeAnnotation::Union(vec![
            TypeAnnotation::Basic("int".to_string()),
            TypeAnnotation::Basic("string".to_string()),
        ]))
    }

    #[test]
    fn test_exhaustive_match_all_variants() {
        let status = enum_type("Status", &["Active", "Inactive"]);
        let expr = match_on(
            "status",
            vec![
                plain_arm(status_variant("Active"), "yes"),
                plain_arm(status_variant("Inactive"), "no"),
            ],
        );

        assert_eq!(
            check_exhaustiveness(&expr, &status),
            ExhaustivenessResult::Exhaustive
        );
    }

    #[test]
    fn test_non_exhaustive_missing_variant() {
        let status = enum_type("Status", &["Active", "Inactive"]);
        let expr = match_on("status", vec![plain_arm(status_variant("Active"), "yes")]);

        match check_exhaustiveness(&expr, &status) {
            ExhaustivenessResult::NonExhaustive {
                enum_name,
                missing_variants,
            } => {
                assert_eq!(enum_name, "Status");
                assert_eq!(missing_variants, vec!["Inactive"]);
            }
            _ => panic!("Expected NonExhaustive"),
        }
    }

    #[test]
    fn test_exhaustive_with_wildcard() {
        let status = enum_type("Status", &["Active", "Inactive", "Pending"]);
        let expr = match_on(
            "status",
            vec![
                plain_arm(status_variant("Active"), "yes"),
                plain_arm(Pattern::Wildcard, "no"),
            ],
        );

        assert_eq!(
            check_exhaustiveness(&expr, &status),
            ExhaustivenessResult::TriviallyExhaustive
        );
    }

    #[test]
    fn test_guarded_pattern_does_not_count() {
        let status = enum_type("Status", &["Active", "Inactive"]);
        // The guarded `Active` arm must not contribute to coverage.
        let always_true = Expr::Literal(Literal::Bool(true), dummy_span());
        let expr = match_on(
            "status",
            vec![
                arm(status_variant("Active"), Some(always_true), string_lit("yes")),
                plain_arm(status_variant("Inactive"), "no"),
            ],
        );

        match check_exhaustiveness(&expr, &status) {
            ExhaustivenessResult::NonExhaustive {
                missing_variants, ..
            } => {
                assert!(missing_variants.contains(&"Active".to_string()));
            }
            _ => panic!("Expected NonExhaustive because guarded Active doesn't count"),
        }
    }

    #[test]
    fn test_non_enum_with_wildcard_is_exhaustive() {
        let number = SemanticType::Number;
        let expr = match_on(
            "x",
            vec![
                plain_arm(Pattern::Literal(Literal::Number(1.0)), "one"),
                plain_arm(Pattern::Wildcard, "other"),
            ],
        );

        assert_eq!(
            check_exhaustiveness(&expr, &number),
            ExhaustivenessResult::TriviallyExhaustive
        );
    }

    #[test]
    fn test_union_typed_patterns_are_exhaustive() {
        let expr = match_on(
            "x",
            vec![
                plain_arm(typed("n", "int"), "int"),
                plain_arm(typed("s", "string"), "string"),
            ],
        );

        assert_eq!(
            check_exhaustiveness_for_type(&expr, &int_or_string()),
            ExhaustivenessResult::Exhaustive
        );
    }

    #[test]
    fn test_union_typed_patterns_missing_variant_reports_non_exhaustive() {
        let expr = match_on("x", vec![plain_arm(typed("n", "int"), "int")]);

        match check_exhaustiveness_for_type(&expr, &int_or_string()) {
            ExhaustivenessResult::NonExhaustive {
                enum_name,
                missing_variants,
            } => {
                assert_eq!(enum_name, "int | string");
                assert_eq!(missing_variants, vec!["string"]);
            }
            other => panic!("Expected NonExhaustive, got {:?}", other),
        }
    }
}
548}