Skip to main content

virtual_rust/interpreter/
pattern.rs

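//! Pattern matching logic for `match` expressions: deciding whether a pattern matches a runtime value, and binding the variables the pattern introduces.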
2
3use crate::ast::{Expr, Pattern};
4use crate::interpreter::error::RuntimeError;
5use crate::interpreter::value::Value;
6use crate::interpreter::Interpreter;
7
8impl Interpreter {
9    /// Returns `true` if the given pattern matches the given value.
10    pub(crate) fn match_pattern(
11        &self,
12        pattern: &Pattern,
13        value: &Value,
14    ) -> Result<bool, RuntimeError> {
15        match pattern {
16            Pattern::Wildcard | Pattern::Ident(_) => Ok(true),
17            Pattern::Literal(lit) => match_literal(lit, value),
18            Pattern::Range {
19                start,
20                end,
21                inclusive,
22            } => match_range(start, end, *inclusive, value),
23            Pattern::Or(patterns) => {
24                for p in patterns {
25                    if self.match_pattern(p, value)? {
26                        return Ok(true);
27                    }
28                }
29                Ok(false)
30            }
31        }
32    }
33
34    /// Binds variables introduced by the pattern (e.g. `Pattern::Ident`).
35    pub(crate) fn bind_pattern(
36        &mut self,
37        pattern: &Pattern,
38        value: &Value,
39    ) -> Result<(), RuntimeError> {
40        match pattern {
41            Pattern::Ident(name) => {
42                self.env.define(name.clone(), value.clone(), true);
43            }
44            Pattern::Or(patterns) => {
45                for p in patterns {
46                    if self.match_pattern(p, value)? {
47                        self.bind_pattern(p, value)?;
48                        break;
49                    }
50                }
51            }
52            _ => {} // Wildcards and literals don't bind
53        }
54        Ok(())
55    }
56}
57
58/// Compares a literal expression against a runtime value.
59fn match_literal(lit: &Expr, value: &Value) -> Result<bool, RuntimeError> {
60    Ok(match (lit, value) {
61        (Expr::IntLiteral(a), Value::Int(b)) => a == b,
62        (Expr::FloatLiteral(a), Value::Float(b)) => a == b,
63        (Expr::StringLiteral(a), Value::String(b)) => a == b,
64        (Expr::CharLiteral(a), Value::Char(b)) => a == b,
65        (Expr::BoolLiteral(a), Value::Bool(b)) => a == b,
66        _ => false,
67    })
68}
69
70/// Checks whether a value falls within a range pattern.
71fn match_range(
72    start: &Expr,
73    end: &Expr,
74    inclusive: bool,
75    value: &Value,
76) -> Result<bool, RuntimeError> {
77    Ok(match (start, end, value) {
78        (Expr::IntLiteral(s), Expr::IntLiteral(e), Value::Int(v)) => {
79            if inclusive {
80                v >= s && v <= e
81            } else {
82                v >= s && v < e
83            }
84        }
85        _ => false,
86    })
87}