1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use shape_ast::ast::{FunctionDef, VarKind};
7use shape_value::KindedSlot;
8
9#[derive(Debug, Clone)]
11pub struct Closure {
12 pub function: Arc<FunctionDef>,
14 pub captured_env: CapturedEnvironment,
16}
17
18impl PartialEq for Closure {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.function, &other.function) && self.captured_env == other.captured_env
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct CapturedEnvironment {
29 pub bindings: HashMap<String, CapturedBinding>,
31 pub parent: Option<Box<CapturedEnvironment>>,
33}
34
35#[derive(Debug, Clone)]
42pub struct CapturedBinding {
43 pub value: KindedSlot,
45 pub kind: VarKind,
47 pub is_mutable: bool,
49}
50
51impl PartialEq for CapturedBinding {
52 fn eq(&self, other: &Self) -> bool {
53 self.kind == other.kind
60 && self.is_mutable == other.is_mutable
61 && self.value.slot().raw() == other.value.slot().raw()
62 && self.value.kind() == other.value.kind()
63 }
64}
65
66impl PartialEq for CapturedEnvironment {
67 fn eq(&self, other: &Self) -> bool {
68 self.bindings == other.bindings && self.parent == other.parent
69 }
70}
71
72impl Default for CapturedEnvironment {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl CapturedEnvironment {
79 pub fn new() -> Self {
81 Self {
82 bindings: HashMap::new(),
83 parent: None,
84 }
85 }
86
87 pub fn with_parent(parent: CapturedEnvironment) -> Self {
89 Self {
90 bindings: HashMap::new(),
91 parent: Some(Box::new(parent)),
92 }
93 }
94
95 pub fn capture(&mut self, name: String, value: KindedSlot, kind: VarKind) {
97 let is_mutable = matches!(kind, VarKind::Var);
98 self.bindings.insert(
99 name,
100 CapturedBinding {
101 value,
102 kind,
103 is_mutable,
104 },
105 );
106 }
107
108 pub fn lookup(&self, name: &str) -> Option<&CapturedBinding> {
110 self.bindings
111 .get(name)
112 .or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
113 }
114
115 pub fn lookup_mut(&mut self, name: &str) -> Option<&mut CapturedBinding> {
117 if self.bindings.contains_key(name) {
118 self.bindings.get_mut(name)
119 } else if let Some(parent) = &mut self.parent {
120 parent.lookup_mut(name)
121 } else {
122 None
123 }
124 }
125
126 pub fn all_captured_names(&self) -> Vec<String> {
128 let mut names: Vec<String> = self.bindings.keys().cloned().collect();
129
130 if let Some(parent) = &self.parent {
131 for name in parent.all_captured_names() {
132 if !names.contains(&name) {
133 names.push(name);
134 }
135 }
136 }
137
138 names
139 }
140}
141
/// Walks a function body to find which outer variables it captures and
/// which of those captures it assigns to.
pub struct EnvironmentAnalyzer {
    /// Lexical scopes, outermost first. Each maps a defined name to `true`
    /// (the only value `define_variable` ever stores).
    scope_stack: Vec<HashMap<String, bool>>,
    /// Captured names mapped to the `scope_stack` index of the binding they
    /// resolved to.
    captured_vars: HashMap<String, usize>,
    /// Captured names that are assigned to inside the analysed function.
    mutated_captures: HashSet<String>,
    /// Index of the scope holding the current function's parameters.
    /// Bindings found in scopes below this index are captures.
    function_scope_level: usize,
}
156
157impl Default for EnvironmentAnalyzer {
158 fn default() -> Self {
159 Self {
160 scope_stack: vec![HashMap::new()],
161 captured_vars: HashMap::new(),
162 mutated_captures: HashSet::new(),
163 function_scope_level: 1,
164 }
165 }
166}
167
168impl EnvironmentAnalyzer {
169 pub fn new() -> Self {
170 Self {
171 scope_stack: vec![HashMap::new()], captured_vars: HashMap::new(),
173 mutated_captures: HashSet::new(),
174 function_scope_level: 1, }
176 }
177
178 pub fn enter_scope(&mut self) {
180 self.scope_stack.push(HashMap::new());
181 }
182
183 pub fn exit_scope(&mut self) {
185 self.scope_stack.pop();
186 }
187
188 pub fn define_variable(&mut self, name: &str) {
190 if let Some(current_scope) = self.scope_stack.last_mut() {
191 current_scope.insert(name.to_string(), true);
192 }
193 }
194
195 pub fn check_variable_reference(&mut self, name: &str) {
201 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
203 if scope.contains_key(name) {
204 if level < self.function_scope_level {
207 self.captured_vars.insert(name.to_string(), level);
208 }
209 return;
210 }
211 }
212 }
213
214 pub fn mark_capture_mutated(&mut self, name: &str) {
216 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
218 if scope.contains_key(name) {
219 if level < self.function_scope_level {
220 self.captured_vars.insert(name.to_string(), level);
221 self.mutated_captures.insert(name.to_string());
222 }
223 return;
224 }
225 }
226 }
227
228 pub fn get_captured_vars(&self) -> Vec<String> {
230 self.captured_vars.keys().cloned().collect()
231 }
232
233 pub fn get_mutated_captures(&self) -> HashSet<String> {
235 self.mutated_captures.clone()
236 }
237
238 pub fn analyze_function(function: &FunctionDef, outer_scope_vars: &[String]) -> Vec<String> {
240 let mut analyzer = Self::new();
241
242 for var in outer_scope_vars {
244 analyzer.define_variable(var);
245 }
246
247 analyzer.enter_scope();
249 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
252
253 for param in &function.params {
255 for name in param.get_identifiers() {
256 analyzer.define_variable(&name);
257 }
258 }
259
260 for stmt in &function.body {
262 analyzer.analyze_statement(stmt);
263 }
264
265 analyzer.get_captured_vars()
266 }
267
268 pub fn analyze_function_with_mutability(
271 function: &FunctionDef,
272 outer_scope_vars: &[String],
273 ) -> (Vec<String>, HashSet<String>) {
274 let mut analyzer = Self::new();
275
276 for var in outer_scope_vars {
278 analyzer.define_variable(var);
279 }
280
281 analyzer.enter_scope();
283 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
284
285 for param in &function.params {
287 for name in param.get_identifiers() {
288 analyzer.define_variable(&name);
289 }
290 }
291
292 for stmt in &function.body {
294 analyzer.analyze_statement(stmt);
295 }
296
297 (
298 analyzer.get_captured_vars(),
299 analyzer.get_mutated_captures(),
300 )
301 }
302
303 fn analyze_statement(&mut self, stmt: &shape_ast::ast::Statement) {
305 use shape_ast::ast::Statement;
306
307 match stmt {
308 Statement::Return(expr, _) => {
309 if let Some(expr) = expr {
310 self.analyze_expr(expr);
311 }
312 }
313 Statement::Expression(expr, _) => {
314 self.analyze_expr(expr);
315 }
316 Statement::VariableDecl(decl, _) => {
317 if let Some(value) = &decl.value {
319 self.analyze_expr(value);
320 }
321 if let Some(name) = decl.pattern.as_identifier() {
323 self.define_variable(name);
324 } else {
325 for name in decl.pattern.get_identifiers() {
327 self.define_variable(&name);
328 }
329 }
330 }
331 Statement::Assignment(assign, _) => {
332 self.analyze_expr(&assign.value);
333 if let Some(name) = assign.pattern.as_identifier() {
334 self.mark_capture_mutated(name);
336 self.check_variable_reference(name);
337 } else {
338 for name in assign.pattern.get_identifiers() {
340 self.mark_capture_mutated(&name);
341 self.check_variable_reference(&name);
342 }
343 }
344 }
345 Statement::If(if_stmt, _) => {
346 self.analyze_expr(&if_stmt.condition);
347 self.enter_scope();
348 for stmt in &if_stmt.then_body {
349 self.analyze_statement(stmt);
350 }
351 self.exit_scope();
352
353 if let Some(else_body) = &if_stmt.else_body {
354 self.enter_scope();
355 for stmt in else_body {
356 self.analyze_statement(stmt);
357 }
358 self.exit_scope();
359 }
360 }
361 Statement::While(while_loop, _) => {
362 self.analyze_expr(&while_loop.condition);
363 self.enter_scope();
364 for stmt in &while_loop.body {
365 self.analyze_statement(stmt);
366 }
367 self.exit_scope();
368 }
369 Statement::For(for_loop, _) => {
370 self.enter_scope();
371
372 match &for_loop.init {
374 shape_ast::ast::ForInit::ForIn { pattern, iter } => {
375 self.analyze_expr(iter);
376 for name in pattern.get_identifiers() {
378 self.define_variable(&name);
379 }
380 }
381 shape_ast::ast::ForInit::ForC {
382 init: _,
383 condition,
384 update,
385 } => {
386 self.analyze_expr(condition);
389 self.analyze_expr(update);
390 }
391 }
392
393 for stmt in &for_loop.body {
394 self.analyze_statement(stmt);
395 }
396
397 self.exit_scope();
398 }
399 Statement::Break(_) | Statement::Continue(_) => {
400 }
402 Statement::Extend(ext, _) => {
403 for method in &ext.methods {
404 self.enter_scope();
405 self.define_variable("self");
406 for param in &method.params {
407 for name in param.get_identifiers() {
408 self.define_variable(&name);
409 }
410 }
411 for stmt in &method.body {
412 self.analyze_statement(stmt);
413 }
414 self.exit_scope();
415 }
416 }
417 Statement::RemoveTarget(_) => {}
418 Statement::SetParamType { .. }
419 | Statement::SetReturnType { .. }
420 | Statement::SetReturnExpr { .. } => {}
421 Statement::SetParamValue { expression, .. } => {
422 self.analyze_expr(expression);
423 }
424 Statement::ReplaceModuleExpr { expression, .. } => {
425 self.analyze_expr(expression);
426 }
427 Statement::ReplaceBodyExpr { expression, .. } => {
428 self.analyze_expr(expression);
429 }
430 Statement::ReplaceBody { body, .. } => {
431 for stmt in body {
432 self.analyze_statement(stmt);
433 }
434 }
435 }
436 }
437
438 fn analyze_expr(&mut self, expr: &shape_ast::ast::Expr) {
440 use shape_ast::ast::Expr;
441
442 match expr {
443 Expr::Identifier(name, _) => {
444 self.check_variable_reference(name);
445 }
446 Expr::Literal(..)
447 | Expr::DataRef(..)
448 | Expr::DataDateTimeRef(..)
449 | Expr::TimeRef(..)
450 | Expr::PatternRef(..) => {
451 }
453 Expr::DataRelativeAccess {
454 reference,
455 index: _,
456 ..
457 } => {
458 self.analyze_expr(reference);
459 }
461 Expr::BinaryOp { left, right, .. } => {
462 self.analyze_expr(left);
463 self.analyze_expr(right);
464 }
465 Expr::FuzzyComparison { left, right, .. } => {
466 self.analyze_expr(left);
467 self.analyze_expr(right);
468 }
469 Expr::UnaryOp { operand, .. } => {
470 self.analyze_expr(operand);
471 }
472 Expr::FunctionCall { name, args, .. } => {
473 self.check_variable_reference(name);
476 for arg in args {
477 self.analyze_expr(arg);
478 }
479 }
480 Expr::QualifiedFunctionCall {
481 namespace,
482 args,
483 ..
484 } => {
485 self.check_variable_reference(namespace);
486 for arg in args {
487 self.analyze_expr(arg);
488 }
489 }
490 Expr::EnumConstructor { payload, .. } => {
491 use shape_ast::ast::EnumConstructorPayload;
492 match payload {
493 EnumConstructorPayload::Unit => {}
494 EnumConstructorPayload::Tuple(values) => {
495 for value in values {
496 self.analyze_expr(value);
497 }
498 }
499 EnumConstructorPayload::Struct(fields) => {
500 for (_, value) in fields {
501 self.analyze_expr(value);
502 }
503 }
504 }
505 }
506 Expr::PropertyAccess { object, .. } => {
507 self.analyze_expr(object);
508 }
509 Expr::Conditional {
510 condition,
511 then_expr,
512 else_expr,
513 ..
514 } => {
515 self.analyze_expr(condition);
516 self.analyze_expr(then_expr);
517 if let Some(else_e) = else_expr {
518 self.analyze_expr(else_e);
519 }
520 }
521 Expr::Array(elements, _) => {
522 for elem in elements {
523 self.analyze_expr(elem);
524 }
525 }
526 Expr::TableRows(rows, _) => {
527 for row in rows {
528 for elem in row {
529 self.analyze_expr(elem);
530 }
531 }
532 }
533 Expr::ListComprehension(comp, _) => {
534 self.enter_scope();
536
537 for clause in &comp.clauses {
539 for name in clause.pattern.get_identifiers() {
541 self.define_variable(&name);
542 }
543
544 self.analyze_expr(&clause.iterable);
546
547 if let Some(filter) = &clause.filter {
549 self.analyze_expr(filter);
550 }
551 }
552
553 self.analyze_expr(&comp.element);
555
556 self.exit_scope();
557 }
558 Expr::Object(entries, _) => {
559 use shape_ast::ast::ObjectEntry;
560 for entry in entries {
561 match entry {
562 ObjectEntry::Field { value, .. } => self.analyze_expr(value),
563 ObjectEntry::Spread(spread_expr) => self.analyze_expr(spread_expr),
564 }
565 }
566 }
567 Expr::IndexAccess {
568 object,
569 index,
570 end_index,
571 ..
572 } => {
573 self.analyze_expr(object);
574 self.analyze_expr(index);
575 if let Some(end) = end_index {
576 self.analyze_expr(end);
577 }
578 }
579 Expr::Block(block, _) => {
580 self.enter_scope();
581 for item in &block.items {
582 match item {
583 shape_ast::ast::BlockItem::VariableDecl(decl) => {
584 if let Some(value) = &decl.value {
585 self.analyze_expr(value);
586 }
587 if let Some(name) = decl.pattern.as_identifier() {
588 self.define_variable(name);
589 }
590 }
591 shape_ast::ast::BlockItem::Assignment(assign) => {
592 self.analyze_expr(&assign.value);
593 if let Some(name) = assign.pattern.as_identifier() {
594 self.mark_capture_mutated(name);
595 self.check_variable_reference(name);
596 } else {
597 for name in assign.pattern.get_identifiers() {
598 self.mark_capture_mutated(&name);
599 self.check_variable_reference(&name);
600 }
601 }
602 }
603 shape_ast::ast::BlockItem::Statement(stmt) => {
604 self.analyze_statement(stmt);
605 }
606 shape_ast::ast::BlockItem::Expression(expr) => {
607 self.analyze_expr(expr);
608 }
609 }
610 }
611 self.exit_scope();
612 }
613 Expr::TypeAssertion { expr, .. } => {
614 self.analyze_expr(expr);
615 }
616 Expr::InstanceOf { expr, .. } => {
617 self.analyze_expr(expr);
618 }
619 Expr::FunctionExpr {
620 params,
621 return_type: _,
622 body,
623 ..
624 } => {
625 let saved_function_scope_level = self.function_scope_level;
627 self.enter_scope();
628 self.function_scope_level = self.scope_stack.len() - 1;
629
630 for param in params {
631 for name in param.get_identifiers() {
632 self.define_variable(&name);
633 }
634 }
635
636 for stmt in body {
637 self.analyze_statement(stmt);
638 }
639
640 self.exit_scope();
641 self.function_scope_level = saved_function_scope_level;
642
643 self.captured_vars
647 .retain(|_, level| *level < saved_function_scope_level);
648 self.mutated_captures
649 .retain(|name| self.captured_vars.contains_key(name));
650 }
651 Expr::Duration(..) => {
652 }
654
655 Expr::If(if_expr, _) => {
657 self.analyze_expr(&if_expr.condition);
658 self.analyze_expr(&if_expr.then_branch);
659 if let Some(else_branch) = &if_expr.else_branch {
660 self.analyze_expr(else_branch);
661 }
662 }
663
664 Expr::While(while_expr, _) => {
665 self.analyze_expr(&while_expr.condition);
666 self.analyze_expr(&while_expr.body);
667 }
668
669 Expr::For(for_expr, _) => {
670 self.enter_scope();
671 self.analyze_pattern(&for_expr.pattern);
673 self.analyze_expr(&for_expr.iterable);
674 self.analyze_expr(&for_expr.body);
675 self.exit_scope();
676 }
677
678 Expr::Loop(loop_expr, _) => {
679 self.analyze_expr(&loop_expr.body);
680 }
681
682 Expr::Let(let_expr, _) => {
683 if let Some(value) = &let_expr.value {
684 self.analyze_expr(value);
685 }
686 self.enter_scope();
687 self.analyze_pattern(&let_expr.pattern);
688 self.analyze_expr(&let_expr.body);
689 self.exit_scope();
690 }
691
692 Expr::Assign(assign, _) => {
693 self.analyze_expr(&assign.value);
694 self.analyze_expr(&assign.target);
695 }
696
697 Expr::Break(value, _) => {
698 if let Some(val) = value {
699 self.analyze_expr(val);
700 }
701 }
702
703 Expr::Continue(_) => {
704 }
706
707 Expr::Return(value, _) => {
708 if let Some(val) = value {
709 self.analyze_expr(val);
710 }
711 }
712
713 Expr::MethodCall { receiver, args, .. } => {
714 self.analyze_expr(receiver);
715 for arg in args {
716 self.analyze_expr(arg);
717 }
718 }
719
720 Expr::Match(match_expr, _) => {
721 self.analyze_expr(&match_expr.scrutinee);
722 for arm in &match_expr.arms {
723 self.enter_scope();
724 self.analyze_pattern(&arm.pattern);
725 if let Some(guard) = &arm.guard {
726 self.analyze_expr(guard);
727 }
728 self.analyze_expr(&arm.body);
729 self.exit_scope();
730 }
731 }
732
733 Expr::Unit(_) => {
734 }
736
737 Expr::Spread(inner_expr, _) => {
738 self.analyze_expr(inner_expr);
740 }
741
742 Expr::DateTime(..) => {
743 }
745 Expr::Range { start, end, .. } => {
746 if let Some(s) = start {
748 self.analyze_expr(s);
749 }
750 if let Some(e) = end {
751 self.analyze_expr(e);
752 }
753 }
754
755 Expr::TimeframeContext { expr, .. } => {
756 self.analyze_expr(expr);
758 }
759
760 Expr::TryOperator(inner, _) => {
761 self.analyze_expr(inner);
763 }
764 Expr::UsingImpl { expr, .. } => {
765 self.analyze_expr(expr);
766 }
767
768 Expr::Await(inner, _) => {
769 self.analyze_expr(inner);
771 }
772
773 Expr::SimulationCall { params, .. } => {
774 for (_, value_expr) in params {
777 self.analyze_expr(value_expr);
778 }
779 }
780
781 Expr::WindowExpr(_, _) => {
782 }
784
785 Expr::FromQuery(from_query, _) => {
786 self.analyze_expr(&from_query.source);
788 for clause in &from_query.clauses {
789 match clause {
790 shape_ast::ast::QueryClause::Where(pred) => {
791 self.analyze_expr(pred);
792 }
793 shape_ast::ast::QueryClause::OrderBy(specs) => {
794 for spec in specs {
795 self.analyze_expr(&spec.key);
796 }
797 }
798 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
799 self.analyze_expr(element);
800 self.analyze_expr(key);
801 }
802 shape_ast::ast::QueryClause::Join {
803 source,
804 left_key,
805 right_key,
806 ..
807 } => {
808 self.analyze_expr(source);
809 self.analyze_expr(left_key);
810 self.analyze_expr(right_key);
811 }
812 shape_ast::ast::QueryClause::Let { value, .. } => {
813 self.analyze_expr(value);
814 }
815 }
816 }
817 self.analyze_expr(&from_query.select);
818 }
819 Expr::StructLiteral { fields, .. } => {
820 for (_, value_expr) in fields {
821 self.analyze_expr(value_expr);
822 }
823 }
824 Expr::Join(join_expr, _) => {
825 for branch in &join_expr.branches {
826 self.analyze_expr(&branch.expr);
827 }
828 }
829 Expr::Annotated { target, .. } => {
830 self.analyze_expr(target);
831 }
832 Expr::AsyncLet(async_let, _) => {
833 self.analyze_expr(&async_let.expr);
834 }
835 Expr::AsyncScope(inner, _) => {
836 self.analyze_expr(inner);
837 }
838 Expr::Comptime(stmts, _) => {
839 for stmt in stmts {
840 self.analyze_statement(stmt);
841 }
842 }
843 Expr::ComptimeFor(cf, _) => {
844 self.analyze_expr(&cf.iterable);
845 for stmt in &cf.body {
846 self.analyze_statement(stmt);
847 }
848 }
849 Expr::Reference { expr: inner, .. } => {
850 self.analyze_expr(inner);
851 }
852 }
853 }
854
855 fn analyze_pattern(&mut self, pattern: &shape_ast::ast::Pattern) {
857 use shape_ast::ast::Pattern;
858
859 match pattern {
860 Pattern::Identifier(name) => {
861 self.define_variable(name);
862 }
863 Pattern::Typed { name, .. } => {
864 self.define_variable(name);
865 }
866 Pattern::Wildcard | Pattern::Literal(_) => {
867 }
869 Pattern::Array(patterns) => {
870 for p in patterns {
871 self.analyze_pattern(p);
872 }
873 }
874 Pattern::Object(fields) => {
875 for (_, p) in fields {
876 self.analyze_pattern(p);
877 }
878 }
879 Pattern::Constructor { fields, .. } => match fields {
880 shape_ast::ast::PatternConstructorFields::Unit => {}
881 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
882 for p in patterns {
883 self.analyze_pattern(p);
884 }
885 }
886 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
887 for (_, p) in fields {
888 self.analyze_pattern(p);
889 }
890 }
891 },
892 }
893 }
894}