1use super::*;
2
3impl BytecodeCompiler {
4 pub(super) fn infer_reference_params_from_types(
5 program: &Program,
6 inferred_types: &HashMap<String, Type>,
7 ) -> HashMap<String, Vec<bool>> {
8 let funcs = Self::collect_program_functions(program);
9 let mut inferred = HashMap::new();
10
11 for (name, func) in funcs {
12 let mut inferred_flags = vec![false; func.params.len()];
13 let Some(Type::Function { params, .. }) = inferred_types.get(&name) else {
14 inferred.insert(name, inferred_flags);
15 continue;
16 };
17
18 for (idx, param) in func.params.iter().enumerate() {
19 if param.type_annotation.is_some()
20 || param.is_reference
21 || param.simple_name().is_none()
22 {
23 continue;
24 }
25 if let Some(inferred_param_ty) = params.get(idx)
26 && Self::type_is_heap_like(inferred_param_ty)
27 {
28 inferred_flags[idx] = true;
29 }
30 }
31 inferred.insert(name, inferred_flags);
32 }
33
34 inferred
35 }
36
37 pub(super) fn analyze_statement_for_ref_mutation(
38 stmt: &shape_ast::ast::Statement,
39 caller_name: &str,
40 param_index_by_name: &HashMap<String, usize>,
41 caller_ref_params: &[bool],
42 callee_ref_params: &HashMap<String, Vec<bool>>,
43 direct_mutates: &mut [bool],
44 edges: &mut Vec<(String, usize, String, usize)>,
45 ) {
46 use shape_ast::ast::{ForInit, Statement};
47
48 match stmt {
49 Statement::Return(Some(expr), _) | Statement::Expression(expr, _) => {
50 Self::analyze_expr_for_ref_mutation(
51 expr,
52 caller_name,
53 param_index_by_name,
54 caller_ref_params,
55 callee_ref_params,
56 direct_mutates,
57 edges,
58 );
59 }
60 Statement::VariableDecl(decl, _) => {
61 if let Some(value) = &decl.value {
62 Self::analyze_expr_for_ref_mutation(
63 value,
64 caller_name,
65 param_index_by_name,
66 caller_ref_params,
67 callee_ref_params,
68 direct_mutates,
69 edges,
70 );
71 }
72 }
73 Statement::Assignment(assign, _) => {
74 if let Some(name) = assign.pattern.as_identifier()
75 && let Some(&idx) = param_index_by_name.get(name)
76 && caller_ref_params.get(idx).copied().unwrap_or(false)
77 {
78 direct_mutates[idx] = true;
79 }
80 Self::analyze_expr_for_ref_mutation(
81 &assign.value,
82 caller_name,
83 param_index_by_name,
84 caller_ref_params,
85 callee_ref_params,
86 direct_mutates,
87 edges,
88 );
89 }
90 Statement::If(if_stmt, _) => {
91 Self::analyze_expr_for_ref_mutation(
92 &if_stmt.condition,
93 caller_name,
94 param_index_by_name,
95 caller_ref_params,
96 callee_ref_params,
97 direct_mutates,
98 edges,
99 );
100 for stmt in &if_stmt.then_body {
101 Self::analyze_statement_for_ref_mutation(
102 stmt,
103 caller_name,
104 param_index_by_name,
105 caller_ref_params,
106 callee_ref_params,
107 direct_mutates,
108 edges,
109 );
110 }
111 if let Some(else_body) = &if_stmt.else_body {
112 for stmt in else_body {
113 Self::analyze_statement_for_ref_mutation(
114 stmt,
115 caller_name,
116 param_index_by_name,
117 caller_ref_params,
118 callee_ref_params,
119 direct_mutates,
120 edges,
121 );
122 }
123 }
124 }
125 Statement::While(while_loop, _) => {
126 Self::analyze_expr_for_ref_mutation(
127 &while_loop.condition,
128 caller_name,
129 param_index_by_name,
130 caller_ref_params,
131 callee_ref_params,
132 direct_mutates,
133 edges,
134 );
135 for stmt in &while_loop.body {
136 Self::analyze_statement_for_ref_mutation(
137 stmt,
138 caller_name,
139 param_index_by_name,
140 caller_ref_params,
141 callee_ref_params,
142 direct_mutates,
143 edges,
144 );
145 }
146 }
147 Statement::For(for_loop, _) => {
148 match &for_loop.init {
149 ForInit::ForIn { iter, .. } => {
150 Self::analyze_expr_for_ref_mutation(
151 iter,
152 caller_name,
153 param_index_by_name,
154 caller_ref_params,
155 callee_ref_params,
156 direct_mutates,
157 edges,
158 );
159 }
160 ForInit::ForC {
161 init,
162 condition,
163 update,
164 } => {
165 Self::analyze_statement_for_ref_mutation(
166 init,
167 caller_name,
168 param_index_by_name,
169 caller_ref_params,
170 callee_ref_params,
171 direct_mutates,
172 edges,
173 );
174 Self::analyze_expr_for_ref_mutation(
175 condition,
176 caller_name,
177 param_index_by_name,
178 caller_ref_params,
179 callee_ref_params,
180 direct_mutates,
181 edges,
182 );
183 Self::analyze_expr_for_ref_mutation(
184 update,
185 caller_name,
186 param_index_by_name,
187 caller_ref_params,
188 callee_ref_params,
189 direct_mutates,
190 edges,
191 );
192 }
193 }
194 for stmt in &for_loop.body {
195 Self::analyze_statement_for_ref_mutation(
196 stmt,
197 caller_name,
198 param_index_by_name,
199 caller_ref_params,
200 callee_ref_params,
201 direct_mutates,
202 edges,
203 );
204 }
205 }
206 Statement::Extend(ext, _) => {
207 for method in &ext.methods {
208 for stmt in &method.body {
209 Self::analyze_statement_for_ref_mutation(
210 stmt,
211 caller_name,
212 param_index_by_name,
213 caller_ref_params,
214 callee_ref_params,
215 direct_mutates,
216 edges,
217 );
218 }
219 }
220 }
221 Statement::SetReturnExpr { expression, .. } => {
222 Self::analyze_expr_for_ref_mutation(
223 expression,
224 caller_name,
225 param_index_by_name,
226 caller_ref_params,
227 callee_ref_params,
228 direct_mutates,
229 edges,
230 );
231 }
232 Statement::ReplaceBodyExpr { expression, .. } => {
233 Self::analyze_expr_for_ref_mutation(
234 expression,
235 caller_name,
236 param_index_by_name,
237 caller_ref_params,
238 callee_ref_params,
239 direct_mutates,
240 edges,
241 );
242 }
243 Statement::ReplaceModuleExpr { expression, .. } => {
244 Self::analyze_expr_for_ref_mutation(
245 expression,
246 caller_name,
247 param_index_by_name,
248 caller_ref_params,
249 callee_ref_params,
250 direct_mutates,
251 edges,
252 );
253 }
254 Statement::ReplaceBody { body, .. } => {
255 for stmt in body {
256 Self::analyze_statement_for_ref_mutation(
257 stmt,
258 caller_name,
259 param_index_by_name,
260 caller_ref_params,
261 callee_ref_params,
262 direct_mutates,
263 edges,
264 );
265 }
266 }
267 Statement::SetParamValue { expression, .. } => {
268 Self::analyze_expr_for_ref_mutation(
269 expression,
270 caller_name,
271 param_index_by_name,
272 caller_ref_params,
273 callee_ref_params,
274 direct_mutates,
275 edges,
276 );
277 }
278 Statement::Break(_)
279 | Statement::Continue(_)
280 | Statement::Return(None, _)
281 | Statement::RemoveTarget(_)
282 | Statement::SetParamType { .. }
283 | Statement::SetReturnType { .. } => {}
284 }
285 }
286
287 pub(super) fn ref_param_index_from_arg(
288 arg: &shape_ast::ast::Expr,
289 param_index_by_name: &HashMap<String, usize>,
290 caller_ref_params: &[bool],
291 ) -> Option<usize> {
292 match arg {
293 shape_ast::ast::Expr::Reference { expr: inner, .. } => match inner.as_ref() {
294 shape_ast::ast::Expr::Identifier(name, _) => param_index_by_name
295 .get(name)
296 .copied()
297 .filter(|idx| caller_ref_params.get(*idx).copied().unwrap_or(false)),
298 _ => None,
299 },
300 shape_ast::ast::Expr::Identifier(name, _) => param_index_by_name
301 .get(name)
302 .copied()
303 .filter(|idx| caller_ref_params.get(*idx).copied().unwrap_or(false)),
304 _ => None,
305 }
306 }
307}
308
309
310impl BytecodeCompiler {
311 pub(super) fn analyze_expr_for_ref_mutation(
312 expr: &shape_ast::ast::Expr,
313 caller_name: &str,
314 param_index_by_name: &HashMap<String, usize>,
315 caller_ref_params: &[bool],
316 callee_ref_params: &HashMap<String, Vec<bool>>,
317 direct_mutates: &mut [bool],
318 edges: &mut Vec<(String, usize, String, usize)>,
319 ) {
320 use shape_ast::ast::Expr;
321 macro_rules! visit_expr {
322 ($e:expr) => {
323 Self::analyze_expr_for_ref_mutation(
324 $e,
325 caller_name,
326 param_index_by_name,
327 caller_ref_params,
328 callee_ref_params,
329 direct_mutates,
330 edges,
331 )
332 };
333 }
334 macro_rules! visit_stmt {
335 ($s:expr) => {
336 Self::analyze_statement_for_ref_mutation(
337 $s,
338 caller_name,
339 param_index_by_name,
340 caller_ref_params,
341 callee_ref_params,
342 direct_mutates,
343 edges,
344 )
345 };
346 }
347
348 match expr {
349 Expr::Assign(assign, _) => {
350 match assign.target.as_ref() {
351 Expr::Identifier(name, _) => {
352 if let Some(&idx) = param_index_by_name.get(name)
353 && caller_ref_params.get(idx).copied().unwrap_or(false)
354 {
355 direct_mutates[idx] = true;
356 }
357 }
358 Expr::IndexAccess { object, .. } | Expr::PropertyAccess { object, .. } => {
359 if let Expr::Identifier(name, _) = object.as_ref()
360 && let Some(&idx) = param_index_by_name.get(name)
361 && caller_ref_params.get(idx).copied().unwrap_or(false)
362 {
363 direct_mutates[idx] = true;
364 }
365 }
366 _ => {}
367 }
368 visit_expr!(&assign.value);
369 }
370 Expr::FunctionCall {
371 name,
372 args,
373 named_args,
374 ..
375 } => {
376 if let Some(callee_params) = callee_ref_params.get(name) {
377 for (arg_idx, arg) in args.iter().enumerate() {
378 if !callee_params.get(arg_idx).copied().unwrap_or(false) {
379 continue;
380 }
381 if let Some(caller_param_idx) = Self::ref_param_index_from_arg(
382 arg,
383 param_index_by_name,
384 caller_ref_params,
385 ) {
386 edges.push((
387 caller_name.to_string(),
388 caller_param_idx,
389 name.clone(),
390 arg_idx,
391 ));
392 }
393 }
394 }
395 for arg in args {
402 visit_expr!(arg);
403 }
404
405 for (_, arg) in named_args {
406 if let Some(idx) =
407 Self::ref_param_index_from_arg(arg, param_index_by_name, caller_ref_params)
408 {
409 direct_mutates[idx] = true;
410 }
411 visit_expr!(arg);
412 }
413 }
414 Expr::QualifiedFunctionCall {
415 namespace,
416 function,
417 args,
418 named_args,
419 ..
420 } => {
421 let scoped_name = format!("{}::{}", namespace, function);
422 if let Some(callee_params) = callee_ref_params.get(&scoped_name) {
423 for (arg_idx, arg) in args.iter().enumerate() {
424 if !callee_params.get(arg_idx).copied().unwrap_or(false) {
425 continue;
426 }
427 if let Some(caller_param_idx) = Self::ref_param_index_from_arg(
428 arg,
429 param_index_by_name,
430 caller_ref_params,
431 ) {
432 edges.push((
433 caller_name.to_string(),
434 caller_param_idx,
435 scoped_name.clone(),
436 arg_idx,
437 ));
438 }
439 }
440 }
441
442 for arg in args {
443 visit_expr!(arg);
444 }
445
446 for (_, arg) in named_args {
447 if let Some(idx) =
448 Self::ref_param_index_from_arg(arg, param_index_by_name, caller_ref_params)
449 {
450 direct_mutates[idx] = true;
451 }
452 visit_expr!(arg);
453 }
454 }
455 Expr::MethodCall {
456 receiver,
457 args,
458 named_args,
459 ..
460 } => {
461 visit_expr!(receiver);
462 for arg in args {
463 visit_expr!(arg);
464 }
465 for (_, arg) in named_args {
466 visit_expr!(arg);
467 }
468 }
469 Expr::UnaryOp { operand, .. }
470 | Expr::Spread(operand, _)
471 | Expr::TryOperator(operand, _)
472 | Expr::Await(operand, _)
473 | Expr::TimeframeContext { expr: operand, .. }
474 | Expr::UsingImpl { expr: operand, .. }
475 | Expr::Reference { expr: operand, .. } => {
476 visit_expr!(operand);
477 }
478 Expr::BinaryOp { left, right, .. } | Expr::FuzzyComparison { left, right, .. } => {
479 visit_expr!(left);
480 visit_expr!(right);
481 }
482 Expr::PropertyAccess { object, .. } => {
483 visit_expr!(object);
484 }
485 Expr::IndexAccess {
486 object,
487 index,
488 end_index,
489 ..
490 } => {
491 visit_expr!(object);
492 visit_expr!(index);
493 if let Some(end) = end_index {
494 visit_expr!(end);
495 }
496 }
497 Expr::Conditional {
498 condition,
499 then_expr,
500 else_expr,
501 ..
502 } => {
503 visit_expr!(condition);
504 visit_expr!(then_expr);
505 if let Some(else_expr) = else_expr {
506 visit_expr!(else_expr);
507 }
508 }
509 Expr::Array(items, _) => {
510 for item in items {
511 visit_expr!(item);
512 }
513 }
514 Expr::TableRows(rows, _) => {
515 for row in rows {
516 for elem in row {
517 visit_expr!(elem);
518 }
519 }
520 }
521 Expr::Object(entries, _) => {
522 for entry in entries {
523 match entry {
524 shape_ast::ast::ObjectEntry::Field { value, .. } => {
525 visit_expr!(value);
526 }
527 shape_ast::ast::ObjectEntry::Spread(spread) => {
528 visit_expr!(spread);
529 }
530 }
531 }
532 }
533 Expr::ListComprehension(comp, _) => {
534 visit_expr!(&comp.element);
535 for clause in &comp.clauses {
536 visit_expr!(&clause.iterable);
537 if let Some(filter) = &clause.filter {
538 visit_expr!(filter);
539 }
540 }
541 }
542 Expr::Block(block, _) => {
543 for item in &block.items {
544 match item {
545 shape_ast::ast::BlockItem::VariableDecl(decl) => {
546 if let Some(value) = &decl.value {
547 visit_expr!(value);
548 }
549 }
550 shape_ast::ast::BlockItem::Assignment(assign) => {
551 if let Some(name) = assign.pattern.as_identifier()
552 && let Some(&idx) = param_index_by_name.get(name)
553 && caller_ref_params.get(idx).copied().unwrap_or(false)
554 {
555 direct_mutates[idx] = true;
556 }
557 visit_expr!(&assign.value);
558 }
559 shape_ast::ast::BlockItem::Statement(stmt) => {
560 visit_stmt!(stmt);
561 }
562 shape_ast::ast::BlockItem::Expression(expr) => {
563 visit_expr!(expr);
564 }
565 }
566 }
567 }
568 Expr::FunctionExpr { body, .. } => {
569 for stmt in body {
570 visit_stmt!(stmt);
571 }
572 }
573 Expr::If(if_expr, _) => {
574 visit_expr!(&if_expr.condition);
575 visit_expr!(&if_expr.then_branch);
576 if let Some(else_branch) = &if_expr.else_branch {
577 visit_expr!(else_branch);
578 }
579 }
580 Expr::While(while_expr, _) => {
581 visit_expr!(&while_expr.condition);
582 visit_expr!(&while_expr.body);
583 }
584 Expr::For(for_expr, _) => {
585 visit_expr!(&for_expr.iterable);
586 visit_expr!(&for_expr.body);
587 }
588 Expr::Loop(loop_expr, _) => {
589 visit_expr!(&loop_expr.body);
590 }
591 Expr::Let(let_expr, _) => {
592 if let Some(value) = &let_expr.value {
593 visit_expr!(value);
594 }
595 visit_expr!(&let_expr.body);
596 }
597 Expr::Match(match_expr, _) => {
598 visit_expr!(&match_expr.scrutinee);
599 for arm in &match_expr.arms {
600 if let Some(guard) = &arm.guard {
601 visit_expr!(guard);
602 }
603 visit_expr!(&arm.body);
604 }
605 }
606 Expr::Join(join_expr, _) => {
607 for branch in &join_expr.branches {
608 visit_expr!(&branch.expr);
609 }
610 }
611 Expr::Annotated { target, .. } => {
612 visit_expr!(target);
613 }
614 Expr::AsyncLet(async_let, _) => {
615 visit_expr!(&async_let.expr);
616 }
617 Expr::AsyncScope(inner, _) => {
618 visit_expr!(inner);
619 }
620 Expr::Comptime(stmts, _) => {
621 for stmt in stmts {
622 visit_stmt!(stmt);
623 }
624 }
625 Expr::ComptimeFor(cf, _) => {
626 visit_expr!(&cf.iterable);
627 for stmt in &cf.body {
628 visit_stmt!(stmt);
629 }
630 }
631 Expr::SimulationCall { params, .. } => {
632 for (_, value) in params {
633 visit_expr!(value);
634 }
635 }
636 Expr::WindowExpr(window_expr, _) => {
637 match &window_expr.function {
638 shape_ast::ast::WindowFunction::Lag { expr, default, .. }
639 | shape_ast::ast::WindowFunction::Lead { expr, default, .. } => {
640 visit_expr!(expr);
641 if let Some(default) = default {
642 visit_expr!(default);
643 }
644 }
645 shape_ast::ast::WindowFunction::FirstValue(expr)
646 | shape_ast::ast::WindowFunction::LastValue(expr)
647 | shape_ast::ast::WindowFunction::NthValue(expr, _)
648 | shape_ast::ast::WindowFunction::Sum(expr)
649 | shape_ast::ast::WindowFunction::Avg(expr)
650 | shape_ast::ast::WindowFunction::Min(expr)
651 | shape_ast::ast::WindowFunction::Max(expr) => {
652 visit_expr!(expr);
653 }
654 shape_ast::ast::WindowFunction::Count(expr) => {
655 if let Some(expr) = expr {
656 visit_expr!(expr);
657 }
658 }
659 shape_ast::ast::WindowFunction::RowNumber
660 | shape_ast::ast::WindowFunction::Rank
661 | shape_ast::ast::WindowFunction::DenseRank
662 | shape_ast::ast::WindowFunction::Ntile(_) => {}
663 }
664
665 for partition_expr in &window_expr.over.partition_by {
666 visit_expr!(partition_expr);
667 }
668 if let Some(order_by) = &window_expr.over.order_by {
669 for (order_expr, _) in &order_by.columns {
670 visit_expr!(order_expr);
671 }
672 }
673 }
674 Expr::FromQuery(fq, _) => {
675 visit_expr!(&fq.source);
676 for clause in &fq.clauses {
677 match clause {
678 shape_ast::ast::QueryClause::Where(expr) => {
679 visit_expr!(expr);
680 }
681 shape_ast::ast::QueryClause::OrderBy(items) => {
682 for item in items {
683 visit_expr!(&item.key);
684 }
685 }
686 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
687 visit_expr!(element);
688 visit_expr!(key);
689 }
690 shape_ast::ast::QueryClause::Let { value, .. } => {
691 visit_expr!(value);
692 }
693 shape_ast::ast::QueryClause::Join {
694 source,
695 left_key,
696 right_key,
697 ..
698 } => {
699 visit_expr!(source);
700 visit_expr!(left_key);
701 visit_expr!(right_key);
702 }
703 }
704 }
705 visit_expr!(&fq.select);
706 }
707 Expr::StructLiteral { fields, .. } => {
708 for (_, value) in fields {
709 visit_expr!(value);
710 }
711 }
712 Expr::EnumConstructor { payload, .. } => match payload {
713 shape_ast::ast::EnumConstructorPayload::Unit => {}
714 shape_ast::ast::EnumConstructorPayload::Tuple(values) => {
715 for value in values {
716 visit_expr!(value);
717 }
718 }
719 shape_ast::ast::EnumConstructorPayload::Struct(fields) => {
720 for (_, value) in fields {
721 visit_expr!(value);
722 }
723 }
724 },
725 Expr::TypeAssertion {
726 expr,
727 meta_param_overrides,
728 ..
729 } => {
730 visit_expr!(expr);
731 if let Some(overrides) = meta_param_overrides {
732 for value in overrides.values() {
733 visit_expr!(value);
734 }
735 }
736 }
737 Expr::InstanceOf { expr, .. } => {
738 visit_expr!(expr);
739 }
740 Expr::Range { start, end, .. } => {
741 if let Some(start) = start {
742 visit_expr!(start);
743 }
744 if let Some(end) = end {
745 visit_expr!(end);
746 }
747 }
748 Expr::DataRelativeAccess { reference, .. } => {
749 visit_expr!(reference);
750 }
751 Expr::Break(Some(expr), _) | Expr::Return(Some(expr), _) => {
752 visit_expr!(expr);
753 }
754 Expr::Literal(..)
755 | Expr::Identifier(..)
756 | Expr::DataRef(..)
757 | Expr::DataDateTimeRef(..)
758 | Expr::TimeRef(..)
759 | Expr::DateTime(..)
760 | Expr::PatternRef(..)
761 | Expr::Unit(..)
762 | Expr::Duration(..)
763 | Expr::Continue(..)
764 | Expr::Break(None, _)
765 | Expr::Return(None, _) => {}
766 }
767 }
768}
769
770
771impl BytecodeCompiler {
772 pub(super) fn infer_reference_model(
773 program: &Program,
774 ) -> (
775 HashMap<String, Vec<bool>>,
776 HashMap<String, Vec<bool>>,
777 HashMap<String, Vec<Option<String>>>,
778 ) {
779 let funcs = Self::collect_program_functions(program);
780 let mut inference = shape_runtime::type_system::inference::TypeInferenceEngine::new();
781 let (types, _) = inference.infer_program_best_effort(program);
782 let inferred_ref_params = Self::infer_reference_params_from_types(program, &types);
783 let inferred_param_type_hints = Self::infer_param_type_hints_from_types(program, &types);
784
785 let mut effective_ref_params: HashMap<String, Vec<bool>> = HashMap::new();
786 for (name, func) in &funcs {
787 let inferred = inferred_ref_params.get(name).cloned().unwrap_or_default();
788 let mut refs = vec![false; func.params.len()];
789 for (idx, param) in func.params.iter().enumerate() {
790 refs[idx] = param.is_reference || inferred.get(idx).copied().unwrap_or(false);
791 }
792 effective_ref_params.insert(name.clone(), refs);
793 }
794
795 let mut direct_mutates: HashMap<String, Vec<bool>> = HashMap::new();
796 let mut edges: Vec<(String, usize, String, usize)> = Vec::new();
797
798 for (name, func) in &funcs {
799 let caller_refs = effective_ref_params
800 .get(name)
801 .cloned()
802 .unwrap_or_else(|| vec![false; func.params.len()]);
803 let mut direct = vec![false; func.params.len()];
804 let mut param_index_by_name: HashMap<String, usize> = HashMap::new();
805 for (idx, param) in func.params.iter().enumerate() {
806 for param_name in param.get_identifiers() {
807 param_index_by_name.insert(param_name, idx);
808 }
809 }
810 for stmt in &func.body {
811 Self::analyze_statement_for_ref_mutation(
812 stmt,
813 name,
814 ¶m_index_by_name,
815 &caller_refs,
816 &effective_ref_params,
817 &mut direct,
818 &mut edges,
819 );
820 }
821 direct_mutates.insert(name.clone(), direct);
822 }
823
824 let mut result = direct_mutates;
825 let mut changed = true;
826 while changed {
827 changed = false;
828 for (caller, caller_idx, callee, callee_idx) in &edges {
829 let callee_mutates = result
830 .get(callee)
831 .and_then(|flags| flags.get(*callee_idx))
832 .copied()
833 .unwrap_or(false);
834 if !callee_mutates {
835 continue;
836 }
837 if let Some(caller_flags) = result.get_mut(caller)
838 && let Some(flag) = caller_flags.get_mut(*caller_idx)
839 && !*flag
840 {
841 *flag = true;
842 changed = true;
843 }
844 }
845 }
846
847 (inferred_ref_params, result, inferred_param_type_hints)
848 }
849
850 pub(super) fn inferred_type_to_hint_name(ty: &Type) -> Option<String> {
851 match ty {
852 Type::Concrete(annotation) => Some(annotation.to_type_string()),
853 Type::Generic { base, args } => {
854 let base_name = Self::inferred_type_to_hint_name(base)?;
855 if args.is_empty() {
856 return Some(base_name);
857 }
858 let mut arg_names = Vec::with_capacity(args.len());
859 for arg in args {
860 arg_names.push(Self::inferred_type_to_hint_name(arg)?);
861 }
862 Some(format!("{}<{}>", base_name, arg_names.join(", ")))
863 }
864 Type::Variable(_) | Type::Constrained { .. } | Type::Function { .. } => None,
865 }
866 }
867
868 pub(super) fn infer_param_type_hints_from_types(
869 program: &Program,
870 inferred_types: &HashMap<String, Type>,
871 ) -> HashMap<String, Vec<Option<String>>> {
872 let funcs = Self::collect_program_functions(program);
873 let mut hints = HashMap::new();
874
875 for (name, func) in funcs {
876 let mut param_hints = vec![None; func.params.len()];
877 let Some(Type::Function { params, .. }) = inferred_types.get(&name) else {
878 hints.insert(name, param_hints);
879 continue;
880 };
881
882 for (idx, param) in func.params.iter().enumerate() {
883 if param.type_annotation.is_some() || param.simple_name().is_none() {
884 continue;
885 }
886 if let Some(inferred_param_ty) = params.get(idx) {
887 param_hints[idx] = Self::inferred_type_to_hint_name(inferred_param_ty);
888 }
889 }
890
891 hints.insert(name, param_hints);
892 }
893
894 hints
895 }
896
897 pub(crate) fn resolve_compiled_annotation_name(
898 &self,
899 annotation: &shape_ast::ast::Annotation,
900 ) -> Option<String> {
901 self.resolve_compiled_annotation_name_str(&annotation.name)
902 }
903
904 pub(crate) fn resolve_compiled_annotation_name_str(&self, name: &str) -> Option<String> {
905 if self.program.compiled_annotations.contains_key(name) {
906 return Some(name.to_string());
907 }
908
909 if name.contains("::") {
910 return None;
911 }
912
913 for module_path in self.module_scope_stack.iter().rev() {
914 let scoped = Self::qualify_module_symbol(module_path, name);
915 if self.program.compiled_annotations.contains_key(&scoped) {
916 return Some(scoped);
917 }
918 }
919
920 if let Some(imported) = self.imported_annotations.get(name) {
921 let hidden_name =
922 Self::qualify_module_symbol(&imported.hidden_module_name, &imported.original_name);
923 if self.program.compiled_annotations.contains_key(&hidden_name) {
924 return Some(hidden_name);
925 }
926 }
927
928 None
929 }
930
931 pub(crate) fn lookup_compiled_annotation(
932 &self,
933 annotation: &shape_ast::ast::Annotation,
934 ) -> Option<(String, crate::bytecode::CompiledAnnotation)> {
935 let resolved_name = self.resolve_compiled_annotation_name(annotation)?;
936 let compiled = self.program.compiled_annotations.get(&resolved_name)?.clone();
937 Some((resolved_name, compiled))
938 }
939
940 pub(crate) fn annotation_matches_compiled_name(
941 &self,
942 annotation: &shape_ast::ast::Annotation,
943 compiled_name: &str,
944 ) -> bool {
945 self.resolve_compiled_annotation_name(annotation)
946 .as_deref()
947 == Some(compiled_name)
948 }
949
950 pub(crate) fn annotation_args_for_compiled_name(
951 &self,
952 annotations: &[shape_ast::ast::Annotation],
953 compiled_name: &str,
954 ) -> Vec<shape_ast::ast::Expr> {
955 annotations
956 .iter()
957 .find(|annotation| self.annotation_matches_compiled_name(annotation, compiled_name))
958 .map(|annotation| annotation.args.clone())
959 .unwrap_or_default()
960 }
961
962 pub(crate) fn is_definition_annotation_target(
963 target_kind: shape_ast::ast::functions::AnnotationTargetKind,
964 ) -> bool {
965 matches!(
966 target_kind,
967 shape_ast::ast::functions::AnnotationTargetKind::Function
968 | shape_ast::ast::functions::AnnotationTargetKind::Type
969 | shape_ast::ast::functions::AnnotationTargetKind::Module
970 )
971 }
972
973 pub(crate) fn validate_annotation_target_usage(
975 &self,
976 ann: &shape_ast::ast::Annotation,
977 target_kind: shape_ast::ast::functions::AnnotationTargetKind,
978 fallback_span: shape_ast::ast::Span,
979 ) -> Result<()> {
980 let Some((_, compiled)) = self.lookup_compiled_annotation(ann) else {
981 let span = if ann.span == shape_ast::ast::Span::DUMMY {
982 fallback_span
983 } else {
984 ann.span
985 };
986 return Err(ShapeError::SemanticError {
987 message: format!("Unknown annotation '@{}'", ann.name),
988 location: Some(self.span_to_source_location(span)),
989 });
990 };
991
992 let has_definition_lifecycle =
993 compiled.on_define_handler.is_some() || compiled.metadata_handler.is_some();
994 if has_definition_lifecycle && !Self::is_definition_annotation_target(target_kind) {
995 let target_label = format!("{:?}", target_kind).to_lowercase();
996 let span = if ann.span == shape_ast::ast::Span::DUMMY {
997 fallback_span
998 } else {
999 ann.span
1000 };
1001 return Err(ShapeError::SemanticError {
1002 message: format!(
1003 "Annotation '{}' defines definition-time lifecycle hooks (`on_define`/`metadata`) and cannot be applied to a {}. Allowed targets for these hooks are: function, type, module",
1004 ann.name, target_label
1005 ),
1006 location: Some(self.span_to_source_location(span)),
1007 });
1008 }
1009
1010 if compiled.allowed_targets.is_empty() || compiled.allowed_targets.contains(&target_kind) {
1011 return Ok(());
1012 }
1013
1014 let allowed: Vec<String> = compiled
1015 .allowed_targets
1016 .iter()
1017 .map(|k| format!("{:?}", k).to_lowercase())
1018 .collect();
1019 let target_label = format!("{:?}", target_kind).to_lowercase();
1020
1021 let span = if ann.span == shape_ast::ast::Span::DUMMY {
1022 fallback_span
1023 } else {
1024 ann.span
1025 };
1026
1027 Err(ShapeError::SemanticError {
1028 message: format!(
1029 "Annotation '{}' cannot be applied to a {}. Allowed targets: {}",
1030 ann.name,
1031 target_label,
1032 allowed.join(", ")
1033 ),
1034 location: Some(self.span_to_source_location(span)),
1035 })
1036 }
1037
1038 pub fn compile(mut self, program: &Program) -> Result<BytecodeProgram> {
1040 let mut program = program.clone();
1042 shape_ast::transform::desugar_program(&mut program);
1043 let analysis_program =
1044 shape_ast::transform::augment_program_with_generated_extends(&program);
1045
1046 let mut known_bindings: Vec<String> = self.module_bindings.keys().cloned().collect();
1049 let namespace_bindings = Self::collect_namespace_import_bindings(&analysis_program);
1050 for item in &analysis_program.items {
1052 if let shape_ast::ast::Item::Import(import_stmt, _) = item {
1053 if import_stmt.from.is_empty() {
1054 continue;
1055 }
1056 match &import_stmt.items {
1057 shape_ast::ast::ImportItems::Namespace { name, alias } => {
1058 let local_name = alias.clone().unwrap_or_else(|| name.clone());
1059 self.module_scope_sources
1060 .entry(local_name)
1061 .or_insert_with(|| import_stmt.from.clone());
1062 }
1063 shape_ast::ast::ImportItems::Named(specs) => {
1064 if specs.iter().any(|spec| spec.is_annotation) {
1065 let hidden_module_name =
1066 crate::module_resolution::hidden_annotation_import_module_name(
1067 &import_stmt.from,
1068 );
1069 self.module_scope_sources
1070 .entry(hidden_module_name)
1071 .or_insert_with(|| import_stmt.from.clone());
1072 }
1073 }
1074 }
1075 }
1076 }
1077 known_bindings.extend(namespace_bindings.iter().cloned());
1078 self.module_namespace_bindings
1079 .extend(namespace_bindings.into_iter());
1080 for namespace in self.module_namespace_bindings.clone() {
1081 let binding_idx = self.get_or_create_module_binding(&namespace);
1082 self.register_extension_module_schema(&namespace);
1083 let module_schema_name = format!("__mod_{}", namespace);
1084 if self
1085 .type_tracker
1086 .schema_registry()
1087 .get(&module_schema_name)
1088 .is_some()
1089 {
1090 self.set_module_binding_type_info(binding_idx, &module_schema_name);
1091 }
1092 }
1093 known_bindings.sort();
1094 known_bindings.dedup();
1095 let analysis_mode = if matches!(self.type_diagnostic_mode, TypeDiagnosticMode::RecoverAll) {
1096 TypeAnalysisMode::RecoverAll
1097 } else {
1098 TypeAnalysisMode::FailFast
1099 };
1100 if let Err(errors) = analyze_program_with_mode(
1101 &analysis_program,
1102 self.source_text.as_deref(),
1103 None,
1104 Some(&known_bindings),
1105 analysis_mode,
1106 ) {
1107 match self.type_diagnostic_mode {
1108 TypeDiagnosticMode::Strict => {
1109 return Err(Self::type_errors_to_shape(errors));
1110 }
1111 TypeDiagnosticMode::ReliableOnly => {
1112 let strict_errors: Vec<_> = errors
1113 .into_iter()
1114 .filter(|error| Self::should_emit_type_diagnostic(&error.error))
1115 .collect();
1116 if !strict_errors.is_empty() {
1117 return Err(Self::type_errors_to_shape(strict_errors));
1118 }
1119 }
1120 TypeDiagnosticMode::RecoverAll => {
1121 self.errors.extend(
1122 errors
1123 .into_iter()
1124 .map(Self::type_error_with_location_to_shape),
1125 );
1126 }
1127 }
1128 }
1129
1130 let (inferred_ref_params, inferred_ref_mutates, inferred_param_type_hints) =
1131 Self::infer_reference_model(&program);
1132 self.inferred_param_pass_modes =
1133 Self::build_param_pass_mode_map(&program, &inferred_ref_params, &inferred_ref_mutates);
1134 self.inferred_ref_params = inferred_ref_params;
1135 self.inferred_ref_mutates = inferred_ref_mutates;
1136 self.inferred_param_type_hints = inferred_param_type_hints;
1137
1138 {
1159 use shape_runtime::type_system::inference::PropertyAssignmentCollector;
1160 let assignments = PropertyAssignmentCollector::collect(&program);
1161 let grouped = PropertyAssignmentCollector::group_by_variable(&assignments);
1162 for (var_name, var_assignments) in grouped {
1163 let field_names: Vec<String> =
1164 var_assignments.iter().map(|a| a.property.clone()).collect();
1165 self.hoisted_fields.insert(var_name, field_names);
1166 }
1167 }
1168
1169 for item in &program.items {
1171 self.register_item_functions(item)?;
1172 }
1173
1174 if let Err(e) = self.analyze_non_function_items_with_mir("__main__", &program.items) {
1178 self.errors.push(e);
1179 }
1180
1181 self.current_blob_builder = Some(FunctionBlobBuilder::new(
1183 "__main__".to_string(),
1184 self.program.current_offset(),
1185 self.program.constants.len(),
1186 self.program.strings.len(),
1187 ));
1188
1189 self.push_drop_scope();
1192 self.non_function_mir_context_stack
1193 .push("__main__".to_string());
1194
1195 let item_count = program.items.len();
1197 for (idx, item) in program.items.iter().enumerate() {
1198 let is_last = idx == item_count - 1;
1199 let future_names =
1200 self.future_reference_use_names_for_remaining_items(&program.items[idx + 1..]);
1201 self.push_future_reference_use_names(future_names);
1202 let compile_result = self.compile_item_with_context(item, is_last);
1203 self.pop_future_reference_use_names();
1204 if let Err(e) = compile_result {
1205 self.errors.push(e);
1206 }
1207 self.release_unused_module_reference_borrows_for_remaining_items(
1208 &program.items[idx + 1..],
1209 );
1210 }
1211 self.non_function_mir_context_stack.pop();
1212
1213 if !self.errors.is_empty() {
1215 if self.errors.len() == 1 {
1216 return Err(self.errors.remove(0));
1217 }
1218 return Err(shape_ast::error::ShapeError::MultiError(self.errors));
1219 }
1220
1221 self.pop_drop_scope()?;
1223
1224 {
1226 let bindings: Vec<(u16, bool)> = std::mem::take(&mut self.drop_module_bindings);
1227 for (binding_idx, is_async) in bindings.into_iter().rev() {
1228 self.emit_drop_call_for_module_binding(binding_idx, is_async);
1229 }
1230 }
1231
1232 self.emit(Instruction::simple(OpCode::Halt));
1234
1235 let mut module_binding_names = vec![String::new(); self.module_bindings.len()];
1238 for (name, &idx) in &self.module_bindings {
1239 module_binding_names[idx as usize] = name.clone();
1240 }
1241 self.program.module_binding_names = module_binding_names;
1242
1243 self.program.top_level_locals_count = self.next_local;
1245
1246 self.populate_program_storage_hints();
1248
1249 self.program.type_schema_registry = self.type_tracker.schema_registry().clone();
1251
1252 self.program.expanded_function_defs = self.function_defs.clone();
1254
1255 self.build_content_addressed_program();
1257
1258 self.program.content_addressed = self.content_addressed_program.take();
1260 if self.program.functions.is_empty() {
1261 self.program.function_blob_hashes.clear();
1262 } else {
1263 if self.function_hashes_by_id.len() < self.program.functions.len() {
1264 self.function_hashes_by_id
1265 .resize(self.program.functions.len(), None);
1266 } else if self.function_hashes_by_id.len() > self.program.functions.len() {
1267 self.function_hashes_by_id
1268 .truncate(self.program.functions.len());
1269 }
1270 self.program.function_blob_hashes = self.function_hashes_by_id.clone();
1271 }
1272
1273 if let Some(source) = self.source_text {
1275 self.program.debug_info.source_text = source.clone();
1277 if self.program.debug_info.source_map.files.is_empty() {
1279 self.program
1280 .debug_info
1281 .source_map
1282 .add_file("<main>".to_string());
1283 }
1284 if self.program.debug_info.source_map.source_texts.is_empty() {
1285 self.program
1286 .debug_info
1287 .source_map
1288 .set_source_text(0, source);
1289 }
1290 }
1291
1292 Ok(self.program)
1293 }
1294
1295 pub fn compile_with_source(
1297 mut self,
1298 program: &Program,
1299 source: &str,
1300 ) -> Result<BytecodeProgram> {
1301 self.set_source(source);
1302 self.compile(program)
1303 }
1304
1305 pub fn compile_with_graph(
1312 self,
1313 root_program: &Program,
1314 graph: std::sync::Arc<crate::module_graph::ModuleGraph>,
1315 ) -> Result<BytecodeProgram> {
1316 self.compile_with_graph_and_prelude(root_program, graph, &[])
1317 }
1318
1319 pub fn compile_with_graph_and_prelude(
1325 mut self,
1326 root_program: &Program,
1327 graph: std::sync::Arc<crate::module_graph::ModuleGraph>,
1328 _prelude_paths: &[String],
1329 ) -> Result<BytecodeProgram> {
1330 use crate::module_graph::ModuleSourceKind;
1331
1332 self.module_graph = Some(graph.clone());
1333
1334 for &dep_id in graph.topo_order() {
1336 let dep_node = graph.node(dep_id);
1337 match dep_node.source_kind {
1338 ModuleSourceKind::NativeModule => {
1339 self.register_graph_imports_for_module(dep_id, &graph)?;
1340 }
1341 ModuleSourceKind::ShapeSource | ModuleSourceKind::Hybrid => {
1342 self.compile_module_from_graph(dep_id, &graph)?;
1343 }
1344 ModuleSourceKind::CompiledBytecode => {
1345 return Err(shape_ast::error::ShapeError::ModuleError {
1347 message: format!(
1348 "Module '{}' is only available as pre-compiled bytecode",
1349 dep_node.canonical_path
1350 ),
1351 module_path: None,
1352 });
1353 }
1354 }
1355 }
1356
1357 self.register_graph_imports_for_module(graph.root_id(), &graph)?;
1360
1361 let mut stripped_program = root_program.clone();
1363 stripped_program.items.retain(|item| !matches!(item, shape_ast::ast::Item::Import(..)));
1364
1365 self.compile(&stripped_program)
1367 }
1368
1369 fn compile_module_from_graph(
1375 &mut self,
1376 module_id: crate::module_graph::ModuleId,
1377 graph: &crate::module_graph::ModuleGraph,
1378 ) -> Result<()> {
1379 let node = graph.node(module_id);
1380 let ast = match &node.ast {
1381 Some(ast) => ast.clone(),
1382 None => return Ok(()), };
1384
1385 let module_path = node.canonical_path.clone();
1386
1387 let prev_allow = self.allow_internal_builtins;
1390 if module_path.starts_with("std::") {
1391 self.allow_internal_builtins = true;
1392 }
1393
1394 self.module_scope_stack.push(module_path.clone());
1395
1396 self.register_graph_imports_for_module(module_id, graph)?;
1398
1399 let mut qualified_items = Vec::new();
1401 for item in &ast.items {
1402 if matches!(item, shape_ast::ast::Item::Import(..)) {
1403 continue;
1404 }
1405 qualified_items.push(self.qualify_module_item(item, &module_path)?);
1406 }
1407
1408 for item in &qualified_items {
1410 self.register_missing_module_items(item)?;
1411 }
1412
1413 self.non_function_mir_context_stack
1415 .push(module_path.clone());
1416 let compile_result = (|| -> Result<()> {
1417 for (idx, qualified) in qualified_items.iter().enumerate() {
1418 let future_names = self
1419 .future_reference_use_names_for_remaining_items(&qualified_items[idx + 1..]);
1420 self.push_future_reference_use_names(future_names);
1421 let result = self.compile_item_with_context(qualified, false);
1422 self.pop_future_reference_use_names();
1423 result?;
1424 self.release_unused_module_reference_borrows_for_remaining_items(
1425 &qualified_items[idx + 1..],
1426 );
1427 }
1428 Ok(())
1429 })();
1430 self.non_function_mir_context_stack.pop();
1431 compile_result?;
1432
1433 let exports = self.collect_module_runtime_exports(
1435 &ast.items
1436 .iter()
1437 .filter(|i| !matches!(i, shape_ast::ast::Item::Import(..)))
1438 .cloned()
1439 .collect::<Vec<_>>(),
1440 &module_path,
1441 );
1442 let span = shape_ast::ast::Span::default();
1443 let entries: Vec<shape_ast::ast::ObjectEntry> = exports
1444 .into_iter()
1445 .map(|(name, value_ident)| shape_ast::ast::ObjectEntry::Field {
1446 key: name,
1447 value: shape_ast::ast::Expr::Identifier(value_ident, span),
1448 type_annotation: None,
1449 })
1450 .collect();
1451 let module_object = shape_ast::ast::Expr::Object(entries, span);
1452 self.compile_expr(&module_object)?;
1453
1454 let binding_idx = self.get_or_create_module_binding(&module_path);
1455 self.emit(Instruction::new(
1456 OpCode::StoreModuleBinding,
1457 Some(Operand::ModuleBinding(binding_idx)),
1458 ));
1459 self.propagate_initializer_type_to_slot(binding_idx, false, false);
1460
1461 self.module_scope_stack.pop();
1462 self.allow_internal_builtins = prev_allow;
1463 Ok(())
1464 }
1465
1466 pub fn compile_module_ast(
1477 module_ast: &Program,
1478 ) -> Result<(BytecodeProgram, HashMap<String, usize>)> {
1479 let mut compiler = BytecodeCompiler::new();
1480 compiler.allow_internal_builtins = true;
1482 let bytecode = compiler.compile(module_ast)?;
1483
1484 let mut export_map = HashMap::new();
1486 for (idx, func) in bytecode.functions.iter().enumerate() {
1487 export_map.insert(func.name.clone(), idx);
1488 }
1489
1490 Ok((bytecode, export_map))
1491 }
1492}