1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use shape_ast::ast::{FunctionDef, VarKind};
7use shape_value::ValueWord;
8
9#[derive(Debug, Clone)]
11pub struct Closure {
12 pub function: Arc<FunctionDef>,
14 pub captured_env: CapturedEnvironment,
16}
17
18impl PartialEq for Closure {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.function, &other.function) && self.captured_env == other.captured_env
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct CapturedEnvironment {
29 pub bindings: HashMap<String, CapturedBinding>,
31 pub parent: Option<Box<CapturedEnvironment>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct CapturedBinding {
38 pub value: ValueWord,
40 pub kind: VarKind,
42 pub is_mutable: bool,
44}
45
46impl PartialEq for CapturedBinding {
47 fn eq(&self, other: &Self) -> bool {
48 self.kind == other.kind
49 && self.is_mutable == other.is_mutable
50 && self.value.vw_equals(&other.value)
51 }
52}
53
54impl PartialEq for CapturedEnvironment {
55 fn eq(&self, other: &Self) -> bool {
56 self.bindings == other.bindings && self.parent == other.parent
57 }
58}
59
60impl Default for CapturedEnvironment {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl CapturedEnvironment {
67 pub fn new() -> Self {
69 Self {
70 bindings: HashMap::new(),
71 parent: None,
72 }
73 }
74
75 pub fn with_parent(parent: CapturedEnvironment) -> Self {
77 Self {
78 bindings: HashMap::new(),
79 parent: Some(Box::new(parent)),
80 }
81 }
82
83 pub fn capture(&mut self, name: String, value: ValueWord, kind: VarKind) {
85 let is_mutable = matches!(kind, VarKind::Var);
86 self.bindings.insert(
87 name,
88 CapturedBinding {
89 value,
90 kind,
91 is_mutable,
92 },
93 );
94 }
95
96 pub fn lookup(&self, name: &str) -> Option<&CapturedBinding> {
98 self.bindings
99 .get(name)
100 .or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
101 }
102
103 pub fn lookup_mut(&mut self, name: &str) -> Option<&mut CapturedBinding> {
105 if self.bindings.contains_key(name) {
106 self.bindings.get_mut(name)
107 } else if let Some(parent) = &mut self.parent {
108 parent.lookup_mut(name)
109 } else {
110 None
111 }
112 }
113
114 pub fn all_captured_names(&self) -> Vec<String> {
116 let mut names: Vec<String> = self.bindings.keys().cloned().collect();
117
118 if let Some(parent) = &self.parent {
119 for name in parent.all_captured_names() {
120 if !names.contains(&name) {
121 names.push(name);
122 }
123 }
124 }
125
126 names
127 }
128}
129
/// Walks a function body while tracking lexical scopes, to find which
/// outer variables the body captures and which of those it assigns.
pub struct EnvironmentAnalyzer {
    /// Scope stack, innermost last; each scope maps a declared name to `true`.
    scope_stack: Vec<HashMap<String, bool>>,
    /// Captured names mapped to the index of the scope they resolved to.
    captured_vars: HashMap<String, usize>,
    /// Captured names that are assigned somewhere in the analyzed body.
    mutated_captures: HashSet<String>,
    /// Index of the first scope belonging to the function under analysis;
    /// names resolving to a lower index count as captures.
    function_scope_level: usize,
}
144
145impl Default for EnvironmentAnalyzer {
146 fn default() -> Self {
147 Self {
148 scope_stack: vec![HashMap::new()],
149 captured_vars: HashMap::new(),
150 mutated_captures: HashSet::new(),
151 function_scope_level: 1,
152 }
153 }
154}
155
156impl EnvironmentAnalyzer {
157 pub fn new() -> Self {
158 Self {
159 scope_stack: vec![HashMap::new()], captured_vars: HashMap::new(),
161 mutated_captures: HashSet::new(),
162 function_scope_level: 1, }
164 }
165
166 pub fn enter_scope(&mut self) {
168 self.scope_stack.push(HashMap::new());
169 }
170
171 pub fn exit_scope(&mut self) {
173 self.scope_stack.pop();
174 }
175
176 pub fn define_variable(&mut self, name: &str) {
178 if let Some(current_scope) = self.scope_stack.last_mut() {
179 current_scope.insert(name.to_string(), true);
180 }
181 }
182
183 pub fn check_variable_reference(&mut self, name: &str) {
189 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
191 if scope.contains_key(name) {
192 if level < self.function_scope_level {
195 self.captured_vars.insert(name.to_string(), level);
196 }
197 return;
198 }
199 }
200 }
201
202 pub fn mark_capture_mutated(&mut self, name: &str) {
204 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
206 if scope.contains_key(name) {
207 if level < self.function_scope_level {
208 self.captured_vars.insert(name.to_string(), level);
209 self.mutated_captures.insert(name.to_string());
210 }
211 return;
212 }
213 }
214 }
215
216 pub fn get_captured_vars(&self) -> Vec<String> {
218 self.captured_vars.keys().cloned().collect()
219 }
220
221 pub fn get_mutated_captures(&self) -> HashSet<String> {
223 self.mutated_captures.clone()
224 }
225
226 pub fn analyze_function(function: &FunctionDef, outer_scope_vars: &[String]) -> Vec<String> {
228 let mut analyzer = Self::new();
229
230 for var in outer_scope_vars {
232 analyzer.define_variable(var);
233 }
234
235 analyzer.enter_scope();
237 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
240
241 for param in &function.params {
243 for name in param.get_identifiers() {
244 analyzer.define_variable(&name);
245 }
246 }
247
248 for stmt in &function.body {
250 analyzer.analyze_statement(stmt);
251 }
252
253 analyzer.get_captured_vars()
254 }
255
256 pub fn analyze_function_with_mutability(
259 function: &FunctionDef,
260 outer_scope_vars: &[String],
261 ) -> (Vec<String>, HashSet<String>) {
262 let mut analyzer = Self::new();
263
264 for var in outer_scope_vars {
266 analyzer.define_variable(var);
267 }
268
269 analyzer.enter_scope();
271 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
272
273 for param in &function.params {
275 for name in param.get_identifiers() {
276 analyzer.define_variable(&name);
277 }
278 }
279
280 for stmt in &function.body {
282 analyzer.analyze_statement(stmt);
283 }
284
285 (
286 analyzer.get_captured_vars(),
287 analyzer.get_mutated_captures(),
288 )
289 }
290
291 fn analyze_statement(&mut self, stmt: &shape_ast::ast::Statement) {
293 use shape_ast::ast::Statement;
294
295 match stmt {
296 Statement::Return(expr, _) => {
297 if let Some(expr) = expr {
298 self.analyze_expr(expr);
299 }
300 }
301 Statement::Expression(expr, _) => {
302 self.analyze_expr(expr);
303 }
304 Statement::VariableDecl(decl, _) => {
305 if let Some(value) = &decl.value {
307 self.analyze_expr(value);
308 }
309 if let Some(name) = decl.pattern.as_identifier() {
311 self.define_variable(name);
312 } else {
313 for name in decl.pattern.get_identifiers() {
315 self.define_variable(&name);
316 }
317 }
318 }
319 Statement::Assignment(assign, _) => {
320 self.analyze_expr(&assign.value);
321 if let Some(name) = assign.pattern.as_identifier() {
322 self.mark_capture_mutated(name);
324 self.check_variable_reference(name);
325 } else {
326 for name in assign.pattern.get_identifiers() {
328 self.mark_capture_mutated(&name);
329 self.check_variable_reference(&name);
330 }
331 }
332 }
333 Statement::If(if_stmt, _) => {
334 self.analyze_expr(&if_stmt.condition);
335 self.enter_scope();
336 for stmt in &if_stmt.then_body {
337 self.analyze_statement(stmt);
338 }
339 self.exit_scope();
340
341 if let Some(else_body) = &if_stmt.else_body {
342 self.enter_scope();
343 for stmt in else_body {
344 self.analyze_statement(stmt);
345 }
346 self.exit_scope();
347 }
348 }
349 Statement::While(while_loop, _) => {
350 self.analyze_expr(&while_loop.condition);
351 self.enter_scope();
352 for stmt in &while_loop.body {
353 self.analyze_statement(stmt);
354 }
355 self.exit_scope();
356 }
357 Statement::For(for_loop, _) => {
358 self.enter_scope();
359
360 match &for_loop.init {
362 shape_ast::ast::ForInit::ForIn { pattern, iter } => {
363 self.analyze_expr(iter);
364 for name in pattern.get_identifiers() {
366 self.define_variable(&name);
367 }
368 }
369 shape_ast::ast::ForInit::ForC {
370 init: _,
371 condition,
372 update,
373 } => {
374 self.analyze_expr(condition);
377 self.analyze_expr(update);
378 }
379 }
380
381 for stmt in &for_loop.body {
382 self.analyze_statement(stmt);
383 }
384
385 self.exit_scope();
386 }
387 Statement::Break(_) | Statement::Continue(_) => {
388 }
390 Statement::Extend(ext, _) => {
391 for method in &ext.methods {
392 self.enter_scope();
393 self.define_variable("self");
394 for param in &method.params {
395 for name in param.get_identifiers() {
396 self.define_variable(&name);
397 }
398 }
399 for stmt in &method.body {
400 self.analyze_statement(stmt);
401 }
402 self.exit_scope();
403 }
404 }
405 Statement::RemoveTarget(_) => {}
406 Statement::SetParamType { .. }
407 | Statement::SetReturnType { .. }
408 | Statement::SetReturnExpr { .. } => {}
409 Statement::SetParamValue { expression, .. } => {
410 self.analyze_expr(expression);
411 }
412 Statement::ReplaceModuleExpr { expression, .. } => {
413 self.analyze_expr(expression);
414 }
415 Statement::ReplaceBodyExpr { expression, .. } => {
416 self.analyze_expr(expression);
417 }
418 Statement::ReplaceBody { body, .. } => {
419 for stmt in body {
420 self.analyze_statement(stmt);
421 }
422 }
423 }
424 }
425
426 fn analyze_expr(&mut self, expr: &shape_ast::ast::Expr) {
428 use shape_ast::ast::Expr;
429
430 match expr {
431 Expr::Identifier(name, _) => {
432 self.check_variable_reference(name);
433 }
434 Expr::Literal(..)
435 | Expr::DataRef(..)
436 | Expr::DataDateTimeRef(..)
437 | Expr::TimeRef(..)
438 | Expr::PatternRef(..) => {
439 }
441 Expr::DataRelativeAccess {
442 reference,
443 index: _,
444 ..
445 } => {
446 self.analyze_expr(reference);
447 }
449 Expr::BinaryOp { left, right, .. } => {
450 self.analyze_expr(left);
451 self.analyze_expr(right);
452 }
453 Expr::FuzzyComparison { left, right, .. } => {
454 self.analyze_expr(left);
455 self.analyze_expr(right);
456 }
457 Expr::UnaryOp { operand, .. } => {
458 self.analyze_expr(operand);
459 }
460 Expr::FunctionCall { name, args, .. } => {
461 self.check_variable_reference(name);
464 for arg in args {
465 self.analyze_expr(arg);
466 }
467 }
468 Expr::EnumConstructor { payload, .. } => {
469 use shape_ast::ast::EnumConstructorPayload;
470 match payload {
471 EnumConstructorPayload::Unit => {}
472 EnumConstructorPayload::Tuple(values) => {
473 for value in values {
474 self.analyze_expr(value);
475 }
476 }
477 EnumConstructorPayload::Struct(fields) => {
478 for (_, value) in fields {
479 self.analyze_expr(value);
480 }
481 }
482 }
483 }
484 Expr::PropertyAccess { object, .. } => {
485 self.analyze_expr(object);
486 }
487 Expr::Conditional {
488 condition,
489 then_expr,
490 else_expr,
491 ..
492 } => {
493 self.analyze_expr(condition);
494 self.analyze_expr(then_expr);
495 if let Some(else_e) = else_expr {
496 self.analyze_expr(else_e);
497 }
498 }
499 Expr::Array(elements, _) => {
500 for elem in elements {
501 self.analyze_expr(elem);
502 }
503 }
504 Expr::TableRows(rows, _) => {
505 for row in rows {
506 for elem in row {
507 self.analyze_expr(elem);
508 }
509 }
510 }
511 Expr::ListComprehension(comp, _) => {
512 self.enter_scope();
514
515 for clause in &comp.clauses {
517 for name in clause.pattern.get_identifiers() {
519 self.define_variable(&name);
520 }
521
522 self.analyze_expr(&clause.iterable);
524
525 if let Some(filter) = &clause.filter {
527 self.analyze_expr(filter);
528 }
529 }
530
531 self.analyze_expr(&comp.element);
533
534 self.exit_scope();
535 }
536 Expr::Object(entries, _) => {
537 use shape_ast::ast::ObjectEntry;
538 for entry in entries {
539 match entry {
540 ObjectEntry::Field { value, .. } => self.analyze_expr(value),
541 ObjectEntry::Spread(spread_expr) => self.analyze_expr(spread_expr),
542 }
543 }
544 }
545 Expr::IndexAccess {
546 object,
547 index,
548 end_index,
549 ..
550 } => {
551 self.analyze_expr(object);
552 self.analyze_expr(index);
553 if let Some(end) = end_index {
554 self.analyze_expr(end);
555 }
556 }
557 Expr::Block(block, _) => {
558 self.enter_scope();
559 for item in &block.items {
560 match item {
561 shape_ast::ast::BlockItem::VariableDecl(decl) => {
562 if let Some(value) = &decl.value {
563 self.analyze_expr(value);
564 }
565 if let Some(name) = decl.pattern.as_identifier() {
566 self.define_variable(name);
567 }
568 }
569 shape_ast::ast::BlockItem::Assignment(assign) => {
570 self.analyze_expr(&assign.value);
571 if let Some(name) = assign.pattern.as_identifier() {
572 self.mark_capture_mutated(name);
573 self.check_variable_reference(name);
574 } else {
575 for name in assign.pattern.get_identifiers() {
576 self.mark_capture_mutated(&name);
577 self.check_variable_reference(&name);
578 }
579 }
580 }
581 shape_ast::ast::BlockItem::Statement(stmt) => {
582 self.analyze_statement(stmt);
583 }
584 shape_ast::ast::BlockItem::Expression(expr) => {
585 self.analyze_expr(expr);
586 }
587 }
588 }
589 self.exit_scope();
590 }
591 Expr::TypeAssertion { expr, .. } => {
592 self.analyze_expr(expr);
593 }
594 Expr::InstanceOf { expr, .. } => {
595 self.analyze_expr(expr);
596 }
597 Expr::FunctionExpr {
598 params,
599 return_type: _,
600 body,
601 ..
602 } => {
603 let saved_function_scope_level = self.function_scope_level;
605 self.enter_scope();
606 self.function_scope_level = self.scope_stack.len() - 1;
607
608 for param in params {
609 for name in param.get_identifiers() {
610 self.define_variable(&name);
611 }
612 }
613
614 for stmt in body {
615 self.analyze_statement(stmt);
616 }
617
618 self.exit_scope();
619 self.function_scope_level = saved_function_scope_level;
620
621 self.captured_vars
625 .retain(|_, level| *level < saved_function_scope_level);
626 self.mutated_captures
627 .retain(|name| self.captured_vars.contains_key(name));
628 }
629 Expr::Duration(..) => {
630 }
632
633 Expr::If(if_expr, _) => {
635 self.analyze_expr(&if_expr.condition);
636 self.analyze_expr(&if_expr.then_branch);
637 if let Some(else_branch) = &if_expr.else_branch {
638 self.analyze_expr(else_branch);
639 }
640 }
641
642 Expr::While(while_expr, _) => {
643 self.analyze_expr(&while_expr.condition);
644 self.analyze_expr(&while_expr.body);
645 }
646
647 Expr::For(for_expr, _) => {
648 self.enter_scope();
649 self.analyze_pattern(&for_expr.pattern);
651 self.analyze_expr(&for_expr.iterable);
652 self.analyze_expr(&for_expr.body);
653 self.exit_scope();
654 }
655
656 Expr::Loop(loop_expr, _) => {
657 self.analyze_expr(&loop_expr.body);
658 }
659
660 Expr::Let(let_expr, _) => {
661 if let Some(value) = &let_expr.value {
662 self.analyze_expr(value);
663 }
664 self.enter_scope();
665 self.analyze_pattern(&let_expr.pattern);
666 self.analyze_expr(&let_expr.body);
667 self.exit_scope();
668 }
669
670 Expr::Assign(assign, _) => {
671 self.analyze_expr(&assign.value);
672 self.analyze_expr(&assign.target);
673 }
674
675 Expr::Break(value, _) => {
676 if let Some(val) = value {
677 self.analyze_expr(val);
678 }
679 }
680
681 Expr::Continue(_) => {
682 }
684
685 Expr::Return(value, _) => {
686 if let Some(val) = value {
687 self.analyze_expr(val);
688 }
689 }
690
691 Expr::MethodCall { receiver, args, .. } => {
692 self.analyze_expr(receiver);
693 for arg in args {
694 self.analyze_expr(arg);
695 }
696 }
697
698 Expr::Match(match_expr, _) => {
699 self.analyze_expr(&match_expr.scrutinee);
700 for arm in &match_expr.arms {
701 self.enter_scope();
702 self.analyze_pattern(&arm.pattern);
703 if let Some(guard) = &arm.guard {
704 self.analyze_expr(guard);
705 }
706 self.analyze_expr(&arm.body);
707 self.exit_scope();
708 }
709 }
710
711 Expr::Unit(_) => {
712 }
714
715 Expr::Spread(inner_expr, _) => {
716 self.analyze_expr(inner_expr);
718 }
719
720 Expr::DateTime(..) => {
721 }
723 Expr::Range { start, end, .. } => {
724 if let Some(s) = start {
726 self.analyze_expr(s);
727 }
728 if let Some(e) = end {
729 self.analyze_expr(e);
730 }
731 }
732
733 Expr::TimeframeContext { expr, .. } => {
734 self.analyze_expr(expr);
736 }
737
738 Expr::TryOperator(inner, _) => {
739 self.analyze_expr(inner);
741 }
742 Expr::UsingImpl { expr, .. } => {
743 self.analyze_expr(expr);
744 }
745
746 Expr::Await(inner, _) => {
747 self.analyze_expr(inner);
749 }
750
751 Expr::SimulationCall { params, .. } => {
752 for (_, value_expr) in params {
755 self.analyze_expr(value_expr);
756 }
757 }
758
759 Expr::WindowExpr(_, _) => {
760 }
762
763 Expr::FromQuery(from_query, _) => {
764 self.analyze_expr(&from_query.source);
766 for clause in &from_query.clauses {
767 match clause {
768 shape_ast::ast::QueryClause::Where(pred) => {
769 self.analyze_expr(pred);
770 }
771 shape_ast::ast::QueryClause::OrderBy(specs) => {
772 for spec in specs {
773 self.analyze_expr(&spec.key);
774 }
775 }
776 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
777 self.analyze_expr(element);
778 self.analyze_expr(key);
779 }
780 shape_ast::ast::QueryClause::Join {
781 source,
782 left_key,
783 right_key,
784 ..
785 } => {
786 self.analyze_expr(source);
787 self.analyze_expr(left_key);
788 self.analyze_expr(right_key);
789 }
790 shape_ast::ast::QueryClause::Let { value, .. } => {
791 self.analyze_expr(value);
792 }
793 }
794 }
795 self.analyze_expr(&from_query.select);
796 }
797 Expr::StructLiteral { fields, .. } => {
798 for (_, value_expr) in fields {
799 self.analyze_expr(value_expr);
800 }
801 }
802 Expr::Join(join_expr, _) => {
803 for branch in &join_expr.branches {
804 self.analyze_expr(&branch.expr);
805 }
806 }
807 Expr::Annotated { target, .. } => {
808 self.analyze_expr(target);
809 }
810 Expr::AsyncLet(async_let, _) => {
811 self.analyze_expr(&async_let.expr);
812 }
813 Expr::AsyncScope(inner, _) => {
814 self.analyze_expr(inner);
815 }
816 Expr::Comptime(stmts, _) => {
817 for stmt in stmts {
818 self.analyze_statement(stmt);
819 }
820 }
821 Expr::ComptimeFor(cf, _) => {
822 self.analyze_expr(&cf.iterable);
823 for stmt in &cf.body {
824 self.analyze_statement(stmt);
825 }
826 }
827 Expr::Reference { expr: inner, .. } => {
828 self.analyze_expr(inner);
829 }
830 }
831 }
832
833 fn analyze_pattern(&mut self, pattern: &shape_ast::ast::Pattern) {
835 use shape_ast::ast::Pattern;
836
837 match pattern {
838 Pattern::Identifier(name) => {
839 self.define_variable(name);
840 }
841 Pattern::Typed { name, .. } => {
842 self.define_variable(name);
843 }
844 Pattern::Wildcard | Pattern::Literal(_) => {
845 }
847 Pattern::Array(patterns) => {
848 for p in patterns {
849 self.analyze_pattern(p);
850 }
851 }
852 Pattern::Object(fields) => {
853 for (_, p) in fields {
854 self.analyze_pattern(p);
855 }
856 }
857 Pattern::Constructor { fields, .. } => match fields {
858 shape_ast::ast::PatternConstructorFields::Unit => {}
859 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
860 for p in patterns {
861 self.analyze_pattern(p);
862 }
863 }
864 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
865 for (_, p) in fields {
866 self.analyze_pattern(p);
867 }
868 }
869 },
870 }
871 }
872}