Skip to main content

specl_types/
env.rs

1//! Type environment and scope management.
2
3use crate::types::Type;
4use std::collections::HashMap;
5
/// A type environment for tracking variable and type bindings.
///
/// Local variables live in a stack of lexical scopes. Everything else (type
/// aliases, constants, state variables, actions, functions) is global and
/// keyed by name. Each global table is its own namespace, so the same name may
/// appear in several of them.
#[derive(Debug, Clone)]
pub struct TypeEnv {
    /// Stack of scopes, innermost last.
    ///
    /// `TypeEnv::new` creates a single global scope and `pop_scope` refuses to
    /// remove it, so this stack is never empty.
    scopes: Vec<Scope>,
    /// Type aliases (global): alias name -> aliased type. The aliased type may
    /// itself mention other aliases through `Type::Named`.
    type_aliases: HashMap<String, Type>,
    /// Constants (global): constant name -> type.
    constants: HashMap<String, Type>,
    /// State variables (global): variable name -> type.
    state_vars: HashMap<String, Type>,
    /// Actions and their signatures, keyed by action name.
    actions: HashMap<String, ActionSig>,
    /// User-defined functions, keyed by function name.
    funcs: HashMap<String, FuncInfo>,
}
22
/// A single lexical scope containing local variable bindings.
///
/// Scopes are stacked inside `TypeEnv`. A binding here shadows any binding of
/// the same name in an outer scope.
#[derive(Debug, Clone, Default)]
struct Scope {
    /// Variable name to type mapping for names introduced in this scope.
    bindings: HashMap<String, Type>,
}
29
/// An action signature.
#[derive(Debug, Clone)]
pub struct ActionSig {
    /// Parameter names and types, as `(name, type)` pairs.
    ///
    /// `Vec` keeps the order the caller supplied. Presumably this is the
    /// declaration order of the action's parameters (confirm with callers).
    pub params: Vec<(String, Type)>,
}
36
/// A user-defined function signature.
#[derive(Debug, Clone)]
pub struct FuncInfo {
    /// Parameter names.
    pub param_names: Vec<String>,
    /// Parameter types (may be type variables for polymorphism).
    ///
    /// This is presumably parallel to `param_names` (same length, same order).
    /// `TypeEnv::define_func` stores both vectors exactly as given and does not
    /// check that their lengths agree.
    pub param_types: Vec<Type>,
}
45
46impl TypeEnv {
47    /// Create a new empty type environment with a global scope.
48    pub fn new() -> Self {
49        Self {
50            scopes: vec![Scope::default()],
51            type_aliases: HashMap::new(),
52            constants: HashMap::new(),
53            state_vars: HashMap::new(),
54            actions: HashMap::new(),
55            funcs: HashMap::new(),
56        }
57    }
58
59    // === Scope management ===
60
61    /// Enter a new scope.
62    pub fn push_scope(&mut self) {
63        self.scopes.push(Scope::default());
64    }
65
66    /// Exit the current scope.
67    pub fn pop_scope(&mut self) {
68        if self.scopes.len() > 1 {
69            self.scopes.pop();
70        }
71    }
72
73    // === Local variables ===
74
75    /// Bind a local variable in the current scope.
76    pub fn bind_local(&mut self, name: String, ty: Type) {
77        if let Some(scope) = self.scopes.last_mut() {
78            scope.bindings.insert(name, ty);
79        }
80    }
81
82    /// Look up a local variable, searching from innermost scope.
83    pub fn lookup_local(&self, name: &str) -> Option<&Type> {
84        for scope in self.scopes.iter().rev() {
85            if let Some(ty) = scope.bindings.get(name) {
86                return Some(ty);
87            }
88        }
89        None
90    }
91
92    // === Type aliases ===
93
94    /// Define a type alias.
95    pub fn define_type_alias(&mut self, name: String, ty: Type) {
96        self.type_aliases.insert(name, ty);
97    }
98
99    /// Look up a type alias.
100    pub fn lookup_type_alias(&self, name: &str) -> Option<&Type> {
101        self.type_aliases.get(name)
102    }
103
104    /// Resolve a named type to its actual type.
105    /// Returns the type itself if it's not an alias.
106    pub fn resolve_type(&self, ty: &Type) -> Type {
107        match ty {
108            Type::Named(name) => {
109                if let Some(aliased) = self.lookup_type_alias(name) {
110                    self.resolve_type(aliased)
111                } else {
112                    ty.clone()
113                }
114            }
115            Type::Set(inner) => Type::Set(Box::new(self.resolve_type(inner))),
116            Type::Seq(inner) => Type::Seq(Box::new(self.resolve_type(inner))),
117            Type::Option(inner) => Type::Option(Box::new(self.resolve_type(inner))),
118            Type::Fn(k, v) => Type::Fn(
119                Box::new(self.resolve_type(k)),
120                Box::new(self.resolve_type(v)),
121            ),
122            _ => ty.clone(),
123        }
124    }
125
126    // === Constants ===
127
128    /// Define a constant.
129    pub fn define_const(&mut self, name: String, ty: Type) {
130        self.constants.insert(name, ty);
131    }
132
133    /// Look up a constant.
134    pub fn lookup_const(&self, name: &str) -> Option<&Type> {
135        self.constants.get(name)
136    }
137
138    // === State variables ===
139
140    /// Define a state variable.
141    pub fn define_var(&mut self, name: String, ty: Type) {
142        self.state_vars.insert(name, ty);
143    }
144
145    /// Look up a state variable.
146    pub fn lookup_var(&self, name: &str) -> Option<&Type> {
147        self.state_vars.get(name)
148    }
149
150    /// Get all state variable names.
151    pub fn state_var_names(&self) -> impl Iterator<Item = &str> {
152        self.state_vars.keys().map(|s| s.as_str())
153    }
154
155    // === Actions ===
156
157    /// Define an action.
158    pub fn define_action(&mut self, name: String, sig: ActionSig) {
159        self.actions.insert(name, sig);
160    }
161
162    /// Look up an action.
163    pub fn lookup_action(&self, name: &str) -> Option<&ActionSig> {
164        self.actions.get(name)
165    }
166
167    // === Functions ===
168
169    /// Define a user-defined function.
170    pub fn define_func(&mut self, name: String, param_names: Vec<String>, param_types: Vec<Type>) {
171        self.funcs.insert(
172            name,
173            FuncInfo {
174                param_names,
175                param_types,
176            },
177        );
178    }
179
180    /// Look up a function.
181    pub fn lookup_func(&self, name: &str) -> Option<&FuncInfo> {
182        self.funcs.get(name)
183    }
184
185    // === Unified lookup ===
186
187    /// Look up any identifier (local, const, or var).
188    pub fn lookup_ident(&self, name: &str) -> Option<&Type> {
189        // Check local bindings first
190        if let Some(ty) = self.lookup_local(name) {
191            return Some(ty);
192        }
193        // Then constants
194        if let Some(ty) = self.lookup_const(name) {
195            return Some(ty);
196        }
197        // Then state variables
198        self.lookup_var(name)
199    }
200}
201
202impl Default for TypeEnv {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scope_shadowing() {
        let mut env = TypeEnv::new();
        env.bind_local("x".to_string(), Type::Nat);

        env.push_scope();
        env.bind_local("x".to_string(), Type::Bool);

        assert_eq!(env.lookup_local("x"), Some(&Type::Bool));

        env.pop_scope();
        assert_eq!(env.lookup_local("x"), Some(&Type::Nat));
    }

    #[test]
    fn test_pop_scope_discards_inner_bindings() {
        let mut env = TypeEnv::new();
        env.push_scope();
        env.bind_local("y".to_string(), Type::Int);
        assert_eq!(env.lookup_local("y"), Some(&Type::Int));

        env.pop_scope();
        assert_eq!(env.lookup_local("y"), None);
    }

    #[test]
    fn test_pop_scope_keeps_global_scope() {
        let mut env = TypeEnv::new();
        env.bind_local("x".to_string(), Type::Nat);

        // Popping when only the global scope is left must be a no-op.
        env.pop_scope();
        env.pop_scope();
        assert_eq!(env.lookup_local("x"), Some(&Type::Nat));

        // Binding afterwards must still land in the (still present) global scope.
        env.bind_local("y".to_string(), Type::Bool);
        assert_eq!(env.lookup_local("y"), Some(&Type::Bool));
    }

    #[test]
    fn test_type_resolution() {
        let mut env = TypeEnv::new();
        env.define_type_alias("Counter".to_string(), Type::Nat);
        env.define_type_alias(
            "CounterSet".to_string(),
            Type::Set(Box::new(Type::Named("Counter".to_string()))),
        );

        let resolved = env.resolve_type(&Type::Named("CounterSet".to_string()));
        assert_eq!(resolved, Type::Set(Box::new(Type::Nat)));
    }

    #[test]
    fn test_type_resolution_nested_and_unknown() {
        let mut env = TypeEnv::new();
        env.define_type_alias("Counter".to_string(), Type::Nat);

        // The same alias used twice in one type resolves at every occurrence.
        let ty = Type::Fn(
            Box::new(Type::Named("Counter".to_string())),
            Box::new(Type::Option(Box::new(Type::Named("Counter".to_string())))),
        );
        assert_eq!(
            env.resolve_type(&ty),
            Type::Fn(Box::new(Type::Nat), Box::new(Type::Option(Box::new(Type::Nat)))),
        );

        // Named types that are not aliases are returned unchanged.
        let unknown = Type::Named("Unknown".to_string());
        assert_eq!(env.resolve_type(&unknown), unknown);
    }

    #[test]
    fn test_unified_lookup() {
        let mut env = TypeEnv::new();
        env.define_const("MAX".to_string(), Type::Nat);
        env.define_var("count".to_string(), Type::Int);
        env.bind_local("x".to_string(), Type::Bool);

        assert_eq!(env.lookup_ident("MAX"), Some(&Type::Nat));
        assert_eq!(env.lookup_ident("count"), Some(&Type::Int));
        assert_eq!(env.lookup_ident("x"), Some(&Type::Bool));
        assert_eq!(env.lookup_ident("unknown"), None);
    }

    #[test]
    fn test_lookup_precedence() {
        let mut env = TypeEnv::new();
        env.define_var("x".to_string(), Type::Int);
        assert_eq!(env.lookup_ident("x"), Some(&Type::Int));

        // Constants take precedence over state variables...
        env.define_const("x".to_string(), Type::Nat);
        assert_eq!(env.lookup_ident("x"), Some(&Type::Nat));

        // ...and locals take precedence over both.
        env.bind_local("x".to_string(), Type::Bool);
        assert_eq!(env.lookup_ident("x"), Some(&Type::Bool));
    }

    #[test]
    fn test_actions_and_funcs() {
        let mut env = TypeEnv::new();
        env.define_action(
            "Inc".to_string(),
            ActionSig {
                params: vec![("n".to_string(), Type::Nat)],
            },
        );
        env.define_func("f".to_string(), vec!["a".to_string()], vec![Type::Int]);

        let sig = env.lookup_action("Inc").expect("action was just defined");
        assert_eq!(sig.params, vec![("n".to_string(), Type::Nat)]);
        assert!(env.lookup_action("Dec").is_none());

        let func = env.lookup_func("f").expect("function was just defined");
        assert_eq!(func.param_names, vec!["a".to_string()]);
        assert_eq!(func.param_types, vec![Type::Int]);
        assert!(env.lookup_func("g").is_none());
    }

    #[test]
    fn test_state_var_names() {
        let mut env = TypeEnv::new();
        env.define_var("b".to_string(), Type::Int);
        env.define_var("a".to_string(), Type::Bool);

        // Iteration order is unspecified, so sort before comparing.
        let mut names: Vec<&str> = env.state_var_names().collect();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);
    }
}