1use crate::types::Type;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone)]
8pub struct TypeEnv {
9 scopes: Vec<Scope>,
11 type_aliases: HashMap<String, Type>,
13 constants: HashMap<String, Type>,
15 state_vars: HashMap<String, Type>,
17 actions: HashMap<String, ActionSig>,
19 funcs: HashMap<String, FuncInfo>,
21}
22
23#[derive(Debug, Clone, Default)]
25struct Scope {
26 bindings: HashMap<String, Type>,
28}
29
30#[derive(Debug, Clone)]
32pub struct ActionSig {
33 pub params: Vec<(String, Type)>,
35}
36
37#[derive(Debug, Clone)]
39pub struct FuncInfo {
40 pub param_names: Vec<String>,
42 pub param_types: Vec<Type>,
44}
45
46impl TypeEnv {
47 pub fn new() -> Self {
49 Self {
50 scopes: vec![Scope::default()],
51 type_aliases: HashMap::new(),
52 constants: HashMap::new(),
53 state_vars: HashMap::new(),
54 actions: HashMap::new(),
55 funcs: HashMap::new(),
56 }
57 }
58
59 pub fn push_scope(&mut self) {
63 self.scopes.push(Scope::default());
64 }
65
66 pub fn pop_scope(&mut self) {
68 if self.scopes.len() > 1 {
69 self.scopes.pop();
70 }
71 }
72
73 pub fn bind_local(&mut self, name: String, ty: Type) {
77 if let Some(scope) = self.scopes.last_mut() {
78 scope.bindings.insert(name, ty);
79 }
80 }
81
82 pub fn lookup_local(&self, name: &str) -> Option<&Type> {
84 for scope in self.scopes.iter().rev() {
85 if let Some(ty) = scope.bindings.get(name) {
86 return Some(ty);
87 }
88 }
89 None
90 }
91
92 pub fn define_type_alias(&mut self, name: String, ty: Type) {
96 self.type_aliases.insert(name, ty);
97 }
98
99 pub fn lookup_type_alias(&self, name: &str) -> Option<&Type> {
101 self.type_aliases.get(name)
102 }
103
104 pub fn resolve_type(&self, ty: &Type) -> Type {
107 match ty {
108 Type::Named(name) => {
109 if let Some(aliased) = self.lookup_type_alias(name) {
110 self.resolve_type(aliased)
111 } else {
112 ty.clone()
113 }
114 }
115 Type::Set(inner) => Type::Set(Box::new(self.resolve_type(inner))),
116 Type::Seq(inner) => Type::Seq(Box::new(self.resolve_type(inner))),
117 Type::Option(inner) => Type::Option(Box::new(self.resolve_type(inner))),
118 Type::Fn(k, v) => Type::Fn(
119 Box::new(self.resolve_type(k)),
120 Box::new(self.resolve_type(v)),
121 ),
122 _ => ty.clone(),
123 }
124 }
125
126 pub fn define_const(&mut self, name: String, ty: Type) {
130 self.constants.insert(name, ty);
131 }
132
133 pub fn lookup_const(&self, name: &str) -> Option<&Type> {
135 self.constants.get(name)
136 }
137
138 pub fn define_var(&mut self, name: String, ty: Type) {
142 self.state_vars.insert(name, ty);
143 }
144
145 pub fn lookup_var(&self, name: &str) -> Option<&Type> {
147 self.state_vars.get(name)
148 }
149
150 pub fn state_var_names(&self) -> impl Iterator<Item = &str> {
152 self.state_vars.keys().map(|s| s.as_str())
153 }
154
155 pub fn define_action(&mut self, name: String, sig: ActionSig) {
159 self.actions.insert(name, sig);
160 }
161
162 pub fn lookup_action(&self, name: &str) -> Option<&ActionSig> {
164 self.actions.get(name)
165 }
166
167 pub fn define_func(&mut self, name: String, param_names: Vec<String>, param_types: Vec<Type>) {
171 self.funcs.insert(
172 name,
173 FuncInfo {
174 param_names,
175 param_types,
176 },
177 );
178 }
179
180 pub fn lookup_func(&self, name: &str) -> Option<&FuncInfo> {
182 self.funcs.get(name)
183 }
184
185 pub fn lookup_ident(&self, name: &str) -> Option<&Type> {
189 if let Some(ty) = self.lookup_local(name) {
191 return Some(ty);
192 }
193 if let Some(ty) = self.lookup_const(name) {
195 return Some(ty);
196 }
197 self.lookup_var(name)
199 }
200}
201
202impl Default for TypeEnv {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scope_shadowing() {
        let mut env = TypeEnv::new();
        env.bind_local("x".to_string(), Type::Nat);

        env.push_scope();
        env.bind_local("x".to_string(), Type::Bool);

        assert_eq!(env.lookup_local("x"), Some(&Type::Bool));

        env.pop_scope();
        assert_eq!(env.lookup_local("x"), Some(&Type::Nat));
    }

    #[test]
    fn test_pop_scope_keeps_root() {
        let mut env = TypeEnv::new();
        env.bind_local("x".to_string(), Type::Nat);

        // Popping when only the root scope remains must be a no-op.
        env.pop_scope();
        env.pop_scope();
        assert_eq!(env.lookup_local("x"), Some(&Type::Nat));

        // Binding still works after the extra pops.
        env.bind_local("y".to_string(), Type::Bool);
        assert_eq!(env.lookup_local("y"), Some(&Type::Bool));
    }

    #[test]
    fn test_type_resolution() {
        let mut env = TypeEnv::new();
        env.define_type_alias("Counter".to_string(), Type::Nat);
        env.define_type_alias(
            "CounterSet".to_string(),
            Type::Set(Box::new(Type::Named("Counter".to_string()))),
        );

        let resolved = env.resolve_type(&Type::Named("CounterSet".to_string()));
        assert_eq!(resolved, Type::Set(Box::new(Type::Nat)));
    }

    #[test]
    fn test_type_resolution_nested_and_unknown() {
        let mut env = TypeEnv::new();
        env.define_type_alias("Counter".to_string(), Type::Nat);

        // Aliases nested inside Fn / Seq / Option are expanded too.
        let ty = Type::Fn(
            Box::new(Type::Named("Counter".to_string())),
            Box::new(Type::Seq(Box::new(Type::Option(Box::new(Type::Named(
                "Counter".to_string(),
            )))))),
        );
        let expected = Type::Fn(
            Box::new(Type::Nat),
            Box::new(Type::Seq(Box::new(Type::Option(Box::new(Type::Nat))))),
        );
        assert_eq!(env.resolve_type(&ty), expected);

        // A named type with no alias definition is returned unchanged.
        let unknown = Type::Named("Unknown".to_string());
        assert_eq!(env.resolve_type(&unknown), unknown);
    }

    #[test]
    fn test_unified_lookup() {
        let mut env = TypeEnv::new();
        env.define_const("MAX".to_string(), Type::Nat);
        env.define_var("count".to_string(), Type::Int);
        env.bind_local("x".to_string(), Type::Bool);

        assert_eq!(env.lookup_ident("MAX"), Some(&Type::Nat));
        assert_eq!(env.lookup_ident("count"), Some(&Type::Int));
        assert_eq!(env.lookup_ident("x"), Some(&Type::Bool));
        assert_eq!(env.lookup_ident("unknown"), None);
    }

    #[test]
    fn test_lookup_ident_precedence() {
        let mut env = TypeEnv::new();

        env.define_var("n".to_string(), Type::Int);
        assert_eq!(env.lookup_ident("n"), Some(&Type::Int));

        // A constant with the same name takes priority over the state variable.
        env.define_const("n".to_string(), Type::Nat);
        assert_eq!(env.lookup_ident("n"), Some(&Type::Nat));

        // A local binding takes priority over both.
        env.bind_local("n".to_string(), Type::Bool);
        assert_eq!(env.lookup_ident("n"), Some(&Type::Bool));
    }
}