1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use shape_ast::ast::{FunctionDef, VarKind};
7use shape_value::ValueWord;
8
9#[derive(Debug, Clone)]
11pub struct Closure {
12 pub function: Arc<FunctionDef>,
14 pub captured_env: CapturedEnvironment,
16}
17
18impl PartialEq for Closure {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.function, &other.function) && self.captured_env == other.captured_env
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct CapturedEnvironment {
29 pub bindings: HashMap<String, CapturedBinding>,
31 pub parent: Option<Box<CapturedEnvironment>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct CapturedBinding {
38 pub value: ValueWord,
40 pub kind: VarKind,
42 pub is_mutable: bool,
44}
45
46impl PartialEq for CapturedBinding {
47 fn eq(&self, other: &Self) -> bool {
48 self.kind == other.kind
49 && self.is_mutable == other.is_mutable
50 && self.value.vw_equals(&other.value)
51 }
52}
53
54impl PartialEq for CapturedEnvironment {
55 fn eq(&self, other: &Self) -> bool {
56 self.bindings == other.bindings && self.parent == other.parent
57 }
58}
59
60impl Default for CapturedEnvironment {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl CapturedEnvironment {
67 pub fn new() -> Self {
69 Self {
70 bindings: HashMap::new(),
71 parent: None,
72 }
73 }
74
75 pub fn with_parent(parent: CapturedEnvironment) -> Self {
77 Self {
78 bindings: HashMap::new(),
79 parent: Some(Box::new(parent)),
80 }
81 }
82
83 pub fn capture(&mut self, name: String, value: ValueWord, kind: VarKind) {
85 let is_mutable = matches!(kind, VarKind::Var);
86 self.bindings.insert(
87 name,
88 CapturedBinding {
89 value,
90 kind,
91 is_mutable,
92 },
93 );
94 }
95
96 pub fn lookup(&self, name: &str) -> Option<&CapturedBinding> {
98 self.bindings
99 .get(name)
100 .or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
101 }
102
103 pub fn lookup_mut(&mut self, name: &str) -> Option<&mut CapturedBinding> {
105 if self.bindings.contains_key(name) {
106 self.bindings.get_mut(name)
107 } else if let Some(parent) = &mut self.parent {
108 parent.lookup_mut(name)
109 } else {
110 None
111 }
112 }
113
114 pub fn all_captured_names(&self) -> Vec<String> {
116 let mut names: Vec<String> = self.bindings.keys().cloned().collect();
117
118 if let Some(parent) = &self.parent {
119 for name in parent.all_captured_names() {
120 if !names.contains(&name) {
121 names.push(name);
122 }
123 }
124 }
125
126 names
127 }
128}
129
/// Walks a function body and works out which variables it captures from
/// enclosing scopes, and which of those captures it assigns to.
pub struct EnvironmentAnalyzer {
    /// Lexical scopes, outermost first; each maps a defined variable name to
    /// a marker flag.
    scope_stack: Vec<HashMap<String, bool>>,
    /// Captured variable names mapped to the scope level they resolved to.
    captured_vars: HashMap<String, usize>,
    /// Names of captured variables that are assigned to.
    mutated_captures: HashSet<String>,
    /// Scope level at which the function under analysis begins; a name that
    /// resolves to a lower level is treated as a capture.
    function_scope_level: usize,
}
144
145impl Default for EnvironmentAnalyzer {
146 fn default() -> Self {
147 Self {
148 scope_stack: vec![HashMap::new()],
149 captured_vars: HashMap::new(),
150 mutated_captures: HashSet::new(),
151 function_scope_level: 1,
152 }
153 }
154}
155
156impl EnvironmentAnalyzer {
157 pub fn new() -> Self {
158 Self {
159 scope_stack: vec![HashMap::new()], captured_vars: HashMap::new(),
161 mutated_captures: HashSet::new(),
162 function_scope_level: 1, }
164 }
165
166 pub fn enter_scope(&mut self) {
168 self.scope_stack.push(HashMap::new());
169 }
170
171 pub fn exit_scope(&mut self) {
173 self.scope_stack.pop();
174 }
175
176 pub fn define_variable(&mut self, name: &str) {
178 if let Some(current_scope) = self.scope_stack.last_mut() {
179 current_scope.insert(name.to_string(), true);
180 }
181 }
182
183 pub fn check_variable_reference(&mut self, name: &str) {
189 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
191 if scope.contains_key(name) {
192 if level < self.function_scope_level {
195 self.captured_vars.insert(name.to_string(), level);
196 }
197 return;
198 }
199 }
200 }
201
202 pub fn mark_capture_mutated(&mut self, name: &str) {
204 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
206 if scope.contains_key(name) {
207 if level < self.function_scope_level {
208 self.captured_vars.insert(name.to_string(), level);
209 self.mutated_captures.insert(name.to_string());
210 }
211 return;
212 }
213 }
214 }
215
216 pub fn get_captured_vars(&self) -> Vec<String> {
218 self.captured_vars.keys().cloned().collect()
219 }
220
221 pub fn get_mutated_captures(&self) -> HashSet<String> {
223 self.mutated_captures.clone()
224 }
225
226 pub fn analyze_function(function: &FunctionDef, outer_scope_vars: &[String]) -> Vec<String> {
228 let mut analyzer = Self::new();
229
230 for var in outer_scope_vars {
232 analyzer.define_variable(var);
233 }
234
235 analyzer.enter_scope();
237 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
240
241 for param in &function.params {
243 for name in param.get_identifiers() {
244 analyzer.define_variable(&name);
245 }
246 }
247
248 for stmt in &function.body {
250 analyzer.analyze_statement(stmt);
251 }
252
253 analyzer.get_captured_vars()
254 }
255
256 pub fn analyze_function_with_mutability(
259 function: &FunctionDef,
260 outer_scope_vars: &[String],
261 ) -> (Vec<String>, HashSet<String>) {
262 let mut analyzer = Self::new();
263
264 for var in outer_scope_vars {
266 analyzer.define_variable(var);
267 }
268
269 analyzer.enter_scope();
271 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
272
273 for param in &function.params {
275 for name in param.get_identifiers() {
276 analyzer.define_variable(&name);
277 }
278 }
279
280 for stmt in &function.body {
282 analyzer.analyze_statement(stmt);
283 }
284
285 (
286 analyzer.get_captured_vars(),
287 analyzer.get_mutated_captures(),
288 )
289 }
290
291 fn analyze_statement(&mut self, stmt: &shape_ast::ast::Statement) {
293 use shape_ast::ast::Statement;
294
295 match stmt {
296 Statement::Return(expr, _) => {
297 if let Some(expr) = expr {
298 self.analyze_expr(expr);
299 }
300 }
301 Statement::Expression(expr, _) => {
302 self.analyze_expr(expr);
303 }
304 Statement::VariableDecl(decl, _) => {
305 if let Some(value) = &decl.value {
307 self.analyze_expr(value);
308 }
309 if let Some(name) = decl.pattern.as_identifier() {
311 self.define_variable(name);
312 } else {
313 for name in decl.pattern.get_identifiers() {
315 self.define_variable(&name);
316 }
317 }
318 }
319 Statement::Assignment(assign, _) => {
320 self.analyze_expr(&assign.value);
321 if let Some(name) = assign.pattern.as_identifier() {
322 self.mark_capture_mutated(name);
324 self.check_variable_reference(name);
325 } else {
326 for name in assign.pattern.get_identifiers() {
328 self.mark_capture_mutated(&name);
329 self.check_variable_reference(&name);
330 }
331 }
332 }
333 Statement::If(if_stmt, _) => {
334 self.analyze_expr(&if_stmt.condition);
335 self.enter_scope();
336 for stmt in &if_stmt.then_body {
337 self.analyze_statement(stmt);
338 }
339 self.exit_scope();
340
341 if let Some(else_body) = &if_stmt.else_body {
342 self.enter_scope();
343 for stmt in else_body {
344 self.analyze_statement(stmt);
345 }
346 self.exit_scope();
347 }
348 }
349 Statement::While(while_loop, _) => {
350 self.analyze_expr(&while_loop.condition);
351 self.enter_scope();
352 for stmt in &while_loop.body {
353 self.analyze_statement(stmt);
354 }
355 self.exit_scope();
356 }
357 Statement::For(for_loop, _) => {
358 self.enter_scope();
359
360 match &for_loop.init {
362 shape_ast::ast::ForInit::ForIn { pattern, iter } => {
363 self.analyze_expr(iter);
364 for name in pattern.get_identifiers() {
366 self.define_variable(&name);
367 }
368 }
369 shape_ast::ast::ForInit::ForC {
370 init: _,
371 condition,
372 update,
373 } => {
374 self.analyze_expr(condition);
377 self.analyze_expr(update);
378 }
379 }
380
381 for stmt in &for_loop.body {
382 self.analyze_statement(stmt);
383 }
384
385 self.exit_scope();
386 }
387 Statement::Break(_) | Statement::Continue(_) => {
388 }
390 Statement::Extend(ext, _) => {
391 for method in &ext.methods {
392 self.enter_scope();
393 self.define_variable("self");
394 for param in &method.params {
395 for name in param.get_identifiers() {
396 self.define_variable(&name);
397 }
398 }
399 for stmt in &method.body {
400 self.analyze_statement(stmt);
401 }
402 self.exit_scope();
403 }
404 }
405 Statement::RemoveTarget(_) => {}
406 Statement::SetParamType { .. }
407 | Statement::SetReturnType { .. }
408 | Statement::SetReturnExpr { .. } => {}
409 Statement::ReplaceModuleExpr { expression, .. } => {
410 self.analyze_expr(expression);
411 }
412 Statement::ReplaceBodyExpr { expression, .. } => {
413 self.analyze_expr(expression);
414 }
415 Statement::ReplaceBody { body, .. } => {
416 for stmt in body {
417 self.analyze_statement(stmt);
418 }
419 }
420 }
421 }
422
423 fn analyze_expr(&mut self, expr: &shape_ast::ast::Expr) {
425 use shape_ast::ast::Expr;
426
427 match expr {
428 Expr::Identifier(name, _) => {
429 self.check_variable_reference(name);
430 }
431 Expr::Literal(..)
432 | Expr::DataRef(..)
433 | Expr::DataDateTimeRef(..)
434 | Expr::TimeRef(..)
435 | Expr::PatternRef(..) => {
436 }
438 Expr::DataRelativeAccess {
439 reference,
440 index: _,
441 ..
442 } => {
443 self.analyze_expr(reference);
444 }
446 Expr::BinaryOp { left, right, .. } => {
447 self.analyze_expr(left);
448 self.analyze_expr(right);
449 }
450 Expr::FuzzyComparison { left, right, .. } => {
451 self.analyze_expr(left);
452 self.analyze_expr(right);
453 }
454 Expr::UnaryOp { operand, .. } => {
455 self.analyze_expr(operand);
456 }
457 Expr::FunctionCall { name, args, .. } => {
458 self.check_variable_reference(name);
461 for arg in args {
462 self.analyze_expr(arg);
463 }
464 }
465 Expr::EnumConstructor { payload, .. } => {
466 use shape_ast::ast::EnumConstructorPayload;
467 match payload {
468 EnumConstructorPayload::Unit => {}
469 EnumConstructorPayload::Tuple(values) => {
470 for value in values {
471 self.analyze_expr(value);
472 }
473 }
474 EnumConstructorPayload::Struct(fields) => {
475 for (_, value) in fields {
476 self.analyze_expr(value);
477 }
478 }
479 }
480 }
481 Expr::PropertyAccess { object, .. } => {
482 self.analyze_expr(object);
483 }
484 Expr::Conditional {
485 condition,
486 then_expr,
487 else_expr,
488 ..
489 } => {
490 self.analyze_expr(condition);
491 self.analyze_expr(then_expr);
492 if let Some(else_e) = else_expr {
493 self.analyze_expr(else_e);
494 }
495 }
496 Expr::Array(elements, _) => {
497 for elem in elements {
498 self.analyze_expr(elem);
499 }
500 }
501 Expr::ListComprehension(comp, _) => {
502 self.enter_scope();
504
505 for clause in &comp.clauses {
507 for name in clause.pattern.get_identifiers() {
509 self.define_variable(&name);
510 }
511
512 self.analyze_expr(&clause.iterable);
514
515 if let Some(filter) = &clause.filter {
517 self.analyze_expr(filter);
518 }
519 }
520
521 self.analyze_expr(&comp.element);
523
524 self.exit_scope();
525 }
526 Expr::Object(entries, _) => {
527 use shape_ast::ast::ObjectEntry;
528 for entry in entries {
529 match entry {
530 ObjectEntry::Field { value, .. } => self.analyze_expr(value),
531 ObjectEntry::Spread(spread_expr) => self.analyze_expr(spread_expr),
532 }
533 }
534 }
535 Expr::IndexAccess {
536 object,
537 index,
538 end_index,
539 ..
540 } => {
541 self.analyze_expr(object);
542 self.analyze_expr(index);
543 if let Some(end) = end_index {
544 self.analyze_expr(end);
545 }
546 }
547 Expr::Block(block, _) => {
548 self.enter_scope();
549 for item in &block.items {
550 match item {
551 shape_ast::ast::BlockItem::VariableDecl(decl) => {
552 if let Some(value) = &decl.value {
553 self.analyze_expr(value);
554 }
555 if let Some(name) = decl.pattern.as_identifier() {
556 self.define_variable(name);
557 }
558 }
559 shape_ast::ast::BlockItem::Assignment(assign) => {
560 self.analyze_expr(&assign.value);
561 if let Some(name) = assign.pattern.as_identifier() {
562 self.mark_capture_mutated(name);
563 self.check_variable_reference(name);
564 } else {
565 for name in assign.pattern.get_identifiers() {
566 self.mark_capture_mutated(&name);
567 self.check_variable_reference(&name);
568 }
569 }
570 }
571 shape_ast::ast::BlockItem::Statement(stmt) => {
572 self.analyze_statement(stmt);
573 }
574 shape_ast::ast::BlockItem::Expression(expr) => {
575 self.analyze_expr(expr);
576 }
577 }
578 }
579 self.exit_scope();
580 }
581 Expr::TypeAssertion { expr, .. } => {
582 self.analyze_expr(expr);
583 }
584 Expr::InstanceOf { expr, .. } => {
585 self.analyze_expr(expr);
586 }
587 Expr::FunctionExpr {
588 params,
589 return_type: _,
590 body,
591 ..
592 } => {
593 let saved_function_scope_level = self.function_scope_level;
595 self.enter_scope();
596 self.function_scope_level = self.scope_stack.len() - 1;
597
598 for param in params {
599 for name in param.get_identifiers() {
600 self.define_variable(&name);
601 }
602 }
603
604 for stmt in body {
605 self.analyze_statement(stmt);
606 }
607
608 self.exit_scope();
609 self.function_scope_level = saved_function_scope_level;
610
611 self.captured_vars
615 .retain(|_, level| *level < saved_function_scope_level);
616 self.mutated_captures
617 .retain(|name| self.captured_vars.contains_key(name));
618 }
619 Expr::Duration(..) => {
620 }
622
623 Expr::If(if_expr, _) => {
625 self.analyze_expr(&if_expr.condition);
626 self.analyze_expr(&if_expr.then_branch);
627 if let Some(else_branch) = &if_expr.else_branch {
628 self.analyze_expr(else_branch);
629 }
630 }
631
632 Expr::While(while_expr, _) => {
633 self.analyze_expr(&while_expr.condition);
634 self.analyze_expr(&while_expr.body);
635 }
636
637 Expr::For(for_expr, _) => {
638 self.enter_scope();
639 self.analyze_pattern(&for_expr.pattern);
641 self.analyze_expr(&for_expr.iterable);
642 self.analyze_expr(&for_expr.body);
643 self.exit_scope();
644 }
645
646 Expr::Loop(loop_expr, _) => {
647 self.analyze_expr(&loop_expr.body);
648 }
649
650 Expr::Let(let_expr, _) => {
651 if let Some(value) = &let_expr.value {
652 self.analyze_expr(value);
653 }
654 self.enter_scope();
655 self.analyze_pattern(&let_expr.pattern);
656 self.analyze_expr(&let_expr.body);
657 self.exit_scope();
658 }
659
660 Expr::Assign(assign, _) => {
661 self.analyze_expr(&assign.value);
662 self.analyze_expr(&assign.target);
663 }
664
665 Expr::Break(value, _) => {
666 if let Some(val) = value {
667 self.analyze_expr(val);
668 }
669 }
670
671 Expr::Continue(_) => {
672 }
674
675 Expr::Return(value, _) => {
676 if let Some(val) = value {
677 self.analyze_expr(val);
678 }
679 }
680
681 Expr::MethodCall { receiver, args, .. } => {
682 self.analyze_expr(receiver);
683 for arg in args {
684 self.analyze_expr(arg);
685 }
686 }
687
688 Expr::Match(match_expr, _) => {
689 self.analyze_expr(&match_expr.scrutinee);
690 for arm in &match_expr.arms {
691 self.enter_scope();
692 self.analyze_pattern(&arm.pattern);
693 if let Some(guard) = &arm.guard {
694 self.analyze_expr(guard);
695 }
696 self.analyze_expr(&arm.body);
697 self.exit_scope();
698 }
699 }
700
701 Expr::Unit(_) => {
702 }
704
705 Expr::Spread(inner_expr, _) => {
706 self.analyze_expr(inner_expr);
708 }
709
710 Expr::DateTime(..) => {
711 }
713 Expr::Range { start, end, .. } => {
714 if let Some(s) = start {
716 self.analyze_expr(s);
717 }
718 if let Some(e) = end {
719 self.analyze_expr(e);
720 }
721 }
722
723 Expr::TimeframeContext { expr, .. } => {
724 self.analyze_expr(expr);
726 }
727
728 Expr::TryOperator(inner, _) => {
729 self.analyze_expr(inner);
731 }
732 Expr::UsingImpl { expr, .. } => {
733 self.analyze_expr(expr);
734 }
735
736 Expr::Await(inner, _) => {
737 self.analyze_expr(inner);
739 }
740
741 Expr::SimulationCall { params, .. } => {
742 for (_, value_expr) in params {
745 self.analyze_expr(value_expr);
746 }
747 }
748
749 Expr::WindowExpr(_, _) => {
750 }
752
753 Expr::FromQuery(from_query, _) => {
754 self.analyze_expr(&from_query.source);
756 for clause in &from_query.clauses {
757 match clause {
758 shape_ast::ast::QueryClause::Where(pred) => {
759 self.analyze_expr(pred);
760 }
761 shape_ast::ast::QueryClause::OrderBy(specs) => {
762 for spec in specs {
763 self.analyze_expr(&spec.key);
764 }
765 }
766 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
767 self.analyze_expr(element);
768 self.analyze_expr(key);
769 }
770 shape_ast::ast::QueryClause::Join {
771 source,
772 left_key,
773 right_key,
774 ..
775 } => {
776 self.analyze_expr(source);
777 self.analyze_expr(left_key);
778 self.analyze_expr(right_key);
779 }
780 shape_ast::ast::QueryClause::Let { value, .. } => {
781 self.analyze_expr(value);
782 }
783 }
784 }
785 self.analyze_expr(&from_query.select);
786 }
787 Expr::StructLiteral { fields, .. } => {
788 for (_, value_expr) in fields {
789 self.analyze_expr(value_expr);
790 }
791 }
792 Expr::Join(join_expr, _) => {
793 for branch in &join_expr.branches {
794 self.analyze_expr(&branch.expr);
795 }
796 }
797 Expr::Annotated { target, .. } => {
798 self.analyze_expr(target);
799 }
800 Expr::AsyncLet(async_let, _) => {
801 self.analyze_expr(&async_let.expr);
802 }
803 Expr::AsyncScope(inner, _) => {
804 self.analyze_expr(inner);
805 }
806 Expr::Comptime(stmts, _) => {
807 for stmt in stmts {
808 self.analyze_statement(stmt);
809 }
810 }
811 Expr::ComptimeFor(cf, _) => {
812 self.analyze_expr(&cf.iterable);
813 for stmt in &cf.body {
814 self.analyze_statement(stmt);
815 }
816 }
817 Expr::Reference { expr: inner, .. } => {
818 self.analyze_expr(inner);
819 }
820 }
821 }
822
823 fn analyze_pattern(&mut self, pattern: &shape_ast::ast::Pattern) {
825 use shape_ast::ast::Pattern;
826
827 match pattern {
828 Pattern::Identifier(name) => {
829 self.define_variable(name);
830 }
831 Pattern::Typed { name, .. } => {
832 self.define_variable(name);
833 }
834 Pattern::Wildcard | Pattern::Literal(_) => {
835 }
837 Pattern::Array(patterns) => {
838 for p in patterns {
839 self.analyze_pattern(p);
840 }
841 }
842 Pattern::Object(fields) => {
843 for (_, p) in fields {
844 self.analyze_pattern(p);
845 }
846 }
847 Pattern::Constructor { fields, .. } => match fields {
848 shape_ast::ast::PatternConstructorFields::Unit => {}
849 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
850 for p in patterns {
851 self.analyze_pattern(p);
852 }
853 }
854 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
855 for (_, p) in fields {
856 self.analyze_pattern(p);
857 }
858 }
859 },
860 }
861 }
862}