1use serde::{Deserialize, Serialize};
4use shape_ast::ast::Span;
5use shape_ast::ast::TypeAnnotation;
6use shape_ast::error::{Result, ShapeError, span_to_location};
7use std::collections::HashMap;
8
9use super::types::Type;
10use crate::type_system::environment::TypeAliasEntry;
11use shape_ast::ast::{EnumDef, VarKind};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum Symbol {
16 Variable {
18 ty: Type,
19 kind: VarKind,
20 is_initialized: bool,
21 },
22 Function {
24 params: Vec<Type>,
25 returns: Type,
26 defaults: Vec<bool>,
27 },
28 Module {
30 exports: Vec<String>,
31 type_annotation: TypeAnnotation,
32 },
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct SymbolTable {
38 scopes: Vec<Scope>,
40 enums: HashMap<String, EnumDef>,
42 type_aliases: HashMap<String, TypeAliasEntry>,
44 source: Option<String>,
46 allow_redefinition: bool,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52struct Scope {
53 symbols: HashMap<String, Symbol>,
55}
56
57impl SymbolTable {
58 pub fn new() -> Self {
60 Self {
61 scopes: vec![Scope::new()],
62 enums: HashMap::new(),
63 type_aliases: HashMap::new(),
64 source: None,
65 allow_redefinition: false,
66 }
67 }
68
69 pub fn set_source(&mut self, source: String) {
71 self.source = Some(source);
72 }
73
74 pub fn set_allow_redefinition(&mut self, allow: bool) {
76 self.allow_redefinition = allow;
77 }
78
79 fn error_at(&self, span: Span, message: impl Into<String>) -> ShapeError {
81 let location = self
82 .source
83 .as_ref()
84 .map(|src| span_to_location(src, span, None));
85 ShapeError::SemanticError {
86 message: message.into(),
87 location,
88 }
89 }
90
91 pub fn push_scope(&mut self) {
93 self.scopes.push(Scope::new());
94 }
95
96 pub fn pop_scope(&mut self) {
98 if self.scopes.len() > 1 {
99 self.scopes.pop();
100 }
101 }
102
103 pub fn define_variable(
105 &mut self,
106 name: &str,
107 ty: Type,
108 kind: VarKind,
109 is_initialized: bool,
110 ) -> Result<()> {
111 self.define_variable_at(name, ty, kind, is_initialized, Span::DUMMY)
112 }
113
114 pub fn define_variable_at(
116 &mut self,
117 name: &str,
118 ty: Type,
119 kind: VarKind,
120 is_initialized: bool,
121 span: Span,
122 ) -> Result<()> {
123 let scope = self.scopes.last_mut().unwrap();
124
125 if scope.symbols.contains_key(name) {
126 if self.allow_redefinition {
127 scope.symbols.insert(
128 name.to_string(),
129 Symbol::Variable {
130 ty,
131 kind,
132 is_initialized,
133 },
134 );
135 return Ok(());
136 }
137 return Err(self.error_at(
138 span,
139 format!("Variable '{}' is already defined in this scope", name),
140 ));
141 }
142
143 scope.symbols.insert(
144 name.to_string(),
145 Symbol::Variable {
146 ty,
147 kind,
148 is_initialized,
149 },
150 );
151 Ok(())
152 }
153
154 pub fn define_function(&mut self, name: &str, params: Vec<Type>, returns: Type) -> Result<()> {
156 let defaults = vec![false; params.len()];
157 self.define_function_at_with_defaults(name, params, returns, defaults, Span::DUMMY)
158 }
159
160 pub fn define_function_with_defaults(
162 &mut self,
163 name: &str,
164 params: Vec<Type>,
165 returns: Type,
166 defaults: Vec<bool>,
167 ) -> Result<()> {
168 self.define_function_at_with_defaults(name, params, returns, defaults, Span::DUMMY)
169 }
170
171 pub fn define_function_at(
173 &mut self,
174 name: &str,
175 params: Vec<Type>,
176 returns: Type,
177 span: Span,
178 ) -> Result<()> {
179 let defaults = vec![false; params.len()];
180 self.define_function_at_with_defaults(name, params, returns, defaults, span)
181 }
182
183 fn define_function_at_with_defaults(
184 &mut self,
185 name: &str,
186 params: Vec<Type>,
187 returns: Type,
188 defaults: Vec<bool>,
189 span: Span,
190 ) -> Result<()> {
191 let scope = self.scopes.last_mut().unwrap();
192
193 if scope.symbols.contains_key(name) {
194 return Err(self.error_at(
195 span,
196 format!("Function '{}' is already defined in this scope", name),
197 ));
198 }
199
200 scope.symbols.insert(
201 name.to_string(),
202 Symbol::Function {
203 params,
204 returns,
205 defaults,
206 },
207 );
208 Ok(())
209 }
210
211 pub fn define_enum(&mut self, enum_def: EnumDef) -> Result<()> {
213 if self.enums.contains_key(&enum_def.name) {
214 return Err(self.error_at(
215 Span::DUMMY,
216 format!("Enum '{}' is already defined", enum_def.name),
217 ));
218 }
219
220 self.enums.insert(enum_def.name.clone(), enum_def);
221 Ok(())
222 }
223
224 pub fn lookup_enum(&self, name: &str) -> Option<&EnumDef> {
226 self.enums.get(name)
227 }
228
229 pub fn define_type_alias(&mut self, name: &str, type_annotation: TypeAnnotation) -> Result<()> {
231 self.define_type_alias_at(name, type_annotation, None, Span::DUMMY)
232 }
233
234 pub fn define_type_alias_at(
236 &mut self,
237 name: &str,
238 type_annotation: TypeAnnotation,
239 meta_param_overrides: Option<HashMap<String, shape_ast::ast::Expr>>,
240 span: Span,
241 ) -> Result<()> {
242 if self.type_aliases.contains_key(name) {
243 return Err(self.error_at(span, format!("Type '{}' is already defined", name)));
244 }
245
246 self.type_aliases.insert(
247 name.to_string(),
248 TypeAliasEntry {
249 type_annotation,
250 meta_param_overrides,
251 },
252 );
253 Ok(())
254 }
255
256 pub fn lookup_type_alias(&self, name: &str) -> Option<&TypeAliasEntry> {
258 self.type_aliases.get(name)
259 }
260
261 pub fn lookup(&self, name: &str) -> Option<&Symbol> {
263 for scope in self.scopes.iter().rev() {
265 if let Some(symbol) = scope.symbols.get(name) {
266 return Some(symbol);
267 }
268 }
269
270 None
271 }
272
273 pub fn update_variable(&mut self, name: &str, new_ty: Type) -> Result<()> {
275 self.update_variable_at(name, new_ty, Span::DUMMY)
276 }
277
278 pub fn update_variable_at(&mut self, name: &str, new_ty: Type, span: Span) -> Result<()> {
280 for scope in self.scopes.iter_mut().rev() {
282 if let Some(symbol) = scope.symbols.get_mut(name) {
283 match symbol {
284 Symbol::Variable {
285 ty,
286 kind,
287 is_initialized,
288 } => {
289 if matches!(kind, VarKind::Const) && *is_initialized {
291 return Err(self.error_at(
292 span,
293 format!("Cannot reassign const variable '{}'", name),
294 ));
295 }
296
297 *ty = new_ty;
299 *is_initialized = true;
300 return Ok(());
301 }
302 _ => return Err(self.error_at(span, format!("'{}' is not a variable", name))),
303 }
304 }
305 }
306
307 Err(self.error_at(span, format!("Undefined variable: '{}'", name)))
308 }
309
310 pub fn lookup_variable(&self, name: &str) -> Option<(&Type, &VarKind, bool)> {
312 match self.lookup(name)? {
313 Symbol::Variable {
314 ty,
315 kind,
316 is_initialized,
317 } => Some((ty, kind, *is_initialized)),
318 _ => None,
319 }
320 }
321
322 pub fn lookup_function(&self, name: &str) -> Option<(&Vec<Type>, &Type, &Vec<bool>)> {
324 match self.lookup(name)? {
325 Symbol::Function {
326 params,
327 returns,
328 defaults,
329 } => Some((params, returns, defaults)),
330 _ => None,
331 }
332 }
333
334 pub fn define_module(
336 &mut self,
337 name: &str,
338 exports: Vec<String>,
339 type_annotation: TypeAnnotation,
340 ) -> Result<()> {
341 let root_scope = &mut self.scopes[0];
342 if root_scope.symbols.contains_key(name) {
343 return Ok(()); }
345 root_scope.symbols.insert(
346 name.to_string(),
347 Symbol::Module {
348 exports,
349 type_annotation,
350 },
351 );
352 Ok(())
353 }
354
355 pub fn lookup_module(&self, name: &str) -> Option<(&Vec<String>, &TypeAnnotation)> {
357 match self.lookup(name)? {
358 Symbol::Module {
359 exports,
360 type_annotation,
361 } => Some((exports, type_annotation)),
362 _ => None,
363 }
364 }
365
366 pub fn is_defined_in_current_scope(&self, name: &str) -> bool {
368 self.scopes.last().unwrap().symbols.contains_key(name)
369 }
370
371 pub fn iter_all_symbols(&self) -> impl Iterator<Item = (&str, &Symbol)> {
376 self.scopes
377 .iter()
378 .flat_map(|scope| scope.symbols.iter().map(|(k, v)| (k.as_str(), v)))
379 }
380
381 pub fn iter_type_aliases(&self) -> impl Iterator<Item = (&str, &TypeAliasEntry)> {
383 self.type_aliases.iter().map(|(k, v)| (k.as_str(), v))
384 }
385
386 pub fn iter_enums(&self) -> impl Iterator<Item = (&str, &EnumDef)> {
388 self.enums.iter().map(|(k, v)| (k.as_str(), v))
389 }
390}
391
392impl Scope {
393 fn new() -> Self {
395 Self {
396 symbols: HashMap::new(),
397 }
398 }
399}
400
401impl Default for SymbolTable {
402 fn default() -> Self {
403 Self::new()
404 }
405}