1use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9
10use shape_ast::ast::{
11 Expr, InterfaceMember, Item, Literal, ObjectEntry, ObjectTypeField, Pattern, Program,
12 Statement, TraitMember, TypeAnnotation, VariableDecl,
13};
14use shape_runtime::metadata::UnifiedMetadata;
15use shape_runtime::schema_cache::{
16 DataSourceSchemaCache, EntitySchema, SourceSchema, default_cache_path,
17 load_cached_source_for_uri_with_diagnostics,
18};
19use shape_runtime::type_system::{
20 PropertyAssignmentCollector, Type, TypeInferenceEngine, TypeScheme,
21};
22use shape_runtime::visitor::{Visitor, walk_program};
23use shape_vm::compiler::ParamPassMode;
24use std::path::{Path, PathBuf};
25
26static UNIFIED_METADATA: OnceLock<UnifiedMetadata> = OnceLock::new();
28
29pub fn unified_metadata() -> &'static UnifiedMetadata {
30 UNIFIED_METADATA.get_or_init(UnifiedMetadata::load)
31}
32
33pub fn type_annotation_to_string(ta: &TypeAnnotation) -> Option<String> {
35 match ta {
36 TypeAnnotation::Basic(s) => Some(s.clone()),
37 TypeAnnotation::Array(inner) => {
38 type_annotation_to_string(inner).map(|s| format!("{}[]", s))
39 }
40 TypeAnnotation::Reference(s) => Some(s.to_string()),
41 TypeAnnotation::Generic { name, args } => {
42 let arg_strs: Vec<String> = args.iter().filter_map(type_annotation_to_string).collect();
43 Some(format!("{}<{}>", name, arg_strs.join(", ")))
44 }
45 TypeAnnotation::Void => Some("()".to_string()),
46 TypeAnnotation::Never => Some("never".to_string()),
47 TypeAnnotation::Null => Some("None".to_string()),
48 TypeAnnotation::Undefined => Some("undefined".to_string()),
49 TypeAnnotation::Dyn(traits) => Some(format!("dyn {}", traits.join(" + "))),
50 TypeAnnotation::Tuple(items) => {
51 let strs: Vec<String> = items.iter().filter_map(type_annotation_to_string).collect();
52 Some(format!("({})", strs.join(", ")))
53 }
54 TypeAnnotation::Object(fields) => Some(format_object_shape_from_type_fields(fields)),
55 TypeAnnotation::Function { .. } => Some("Function".to_string()),
56 TypeAnnotation::Union(types) => {
57 let strs: Vec<String> = types.iter().filter_map(type_annotation_to_string).collect();
58 Some(strs.join(" | "))
59 }
60 TypeAnnotation::Intersection(types) => {
61 let strs: Vec<String> = types.iter().filter_map(type_annotation_to_string).collect();
62 merge_structural_intersection_shapes(&strs).or_else(|| Some(strs.join(" + ")))
63 }
64 }
65}
66
67pub fn infer_expr_type(expr: &Expr) -> Option<String> {
69 let env = HashMap::new();
70 infer_expr_type_with_env(expr, &env)
71}
72
73fn infer_expr_type_with_env(expr: &Expr, env: &HashMap<String, String>) -> Option<String> {
74 match expr {
75 Expr::Literal(lit, _) => Some(infer_literal_type(lit)),
76 Expr::FunctionCall { name, .. } => infer_function_return_type(name),
77 Expr::QualifiedFunctionCall {
78 namespace, function, ..
79 } => infer_function_return_type(&format!("{}::{}", namespace, function)),
80 Expr::EnumConstructor { enum_name, .. } => Some(enum_name.to_string()),
81 Expr::MethodCall {
82 receiver, method, ..
83 } => match method.as_str() {
84 "filter" | "where" | "head" | "tail" | "slice" | "reverse" | "concat" | "orderBy"
86 | "limit" | "sort" | "execute" => infer_expr_type_with_env(receiver, env),
87 "sum" | "mean" | "avg" | "min" | "max" | "count" | "reduce" => {
89 Some("number".to_string())
90 }
91 "toString" | "to_string" | "toFixed" => Some("string".to_string()),
93 "type" => Some("Type".to_string()),
95 "length" | "len" => Some("number".to_string()),
97 "isEmpty" | "contains" | "startsWith" | "endsWith" | "some" | "every" | "is_ok"
99 | "is_err" | "is_some" | "is_none" => Some("bool".to_string()),
100 "unwrap" | "unwrap_or" => {
102 if let Some(receiver_type) = infer_expr_type_with_env(receiver, env) {
103 extract_wrapper_inner(&receiver_type)
104 } else {
105 None
106 }
107 }
108 "map" => Some("Array".to_string()),
110 _ => None,
111 },
112 Expr::BinaryOp {
113 op, left, right, ..
114 } => {
115 use shape_ast::ast::BinaryOp;
116 match op {
117 BinaryOp::Equal
118 | BinaryOp::NotEqual
119 | BinaryOp::Less
120 | BinaryOp::LessEq
121 | BinaryOp::Greater
122 | BinaryOp::GreaterEq
123 | BinaryOp::And
124 | BinaryOp::Or
125 | BinaryOp::FuzzyEqual
126 | BinaryOp::FuzzyGreater
127 | BinaryOp::FuzzyLess => Some("bool".to_string()),
128 BinaryOp::Add => {
129 let left_type = infer_expr_type_with_env(left, env);
130 let right_type = infer_expr_type_with_env(right, env);
131 infer_add_type(left_type.as_deref(), right_type.as_deref())
132 .or_else(|| Some("number".to_string()))
133 }
134 BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod | BinaryOp::Pow => {
135 let left_type = infer_expr_type_with_env(left, env);
136 let right_type = infer_expr_type_with_env(right, env);
137 infer_numeric_arithmetic_type(left_type.as_deref(), right_type.as_deref())
138 .or_else(|| Some("number".to_string()))
139 }
140 BinaryOp::NullCoalesce => None,
141 BinaryOp::ErrorContext => Some("Result".to_string()),
142 BinaryOp::Pipe => {
143 if let Some(right_type) = infer_expr_type_with_env(right, env) {
146 Some(right_type)
147 } else {
148 infer_expr_type_with_env(left, env)
150 }
151 }
152 BinaryOp::BitAnd
153 | BinaryOp::BitOr
154 | BinaryOp::BitXor
155 | BinaryOp::BitShl
156 | BinaryOp::BitShr => Some("number".to_string()),
157 }
158 }
159 Expr::Array(elements, _) => Some(infer_array_type(elements)),
160 Expr::Object(entries, _) => Some(infer_object_shape(entries)),
161 Expr::DataRef(_, _) => Some("Row".to_string()),
162 Expr::TryOperator(inner, _) => {
163 if let Some(inner_type) = infer_expr_type_with_env(inner, env) {
164 extract_wrapper_inner(&inner_type)
165 } else {
166 None
167 }
168 }
169 Expr::UsingImpl { expr, .. } => infer_expr_type_with_env(expr, env),
170 Expr::Identifier(name, _) => env.get(name).cloned(),
171 Expr::DataDateTimeRef(_, _) => Some("Data".to_string()),
172 Expr::DataRelativeAccess { .. } => Some("Data".to_string()),
173 Expr::PropertyAccess { .. } => None,
174 Expr::IndexAccess { .. } => None,
175 Expr::UnaryOp { op, .. } => {
176 use shape_ast::ast::UnaryOp;
177 match op {
178 UnaryOp::Not => Some("bool".to_string()),
179 UnaryOp::Neg => Some("number".to_string()),
180 UnaryOp::BitNot => Some("number".to_string()),
181 }
182 }
183 Expr::TimeRef(_, _) => Some("Time".to_string()),
184 Expr::DateTime(_, _) => Some("DateTime".to_string()),
185 Expr::PatternRef(_, _) => Some("Pattern".to_string()),
186 Expr::Conditional { then_expr, .. } => infer_expr_type_with_env(then_expr, env),
187 Expr::Block(_, _) => None,
188 Expr::TypeAssertion {
189 type_annotation, ..
190 } => type_annotation_to_string(type_annotation),
191 Expr::InstanceOf { .. } => Some("bool".to_string()),
192 Expr::FunctionExpr { .. } => Some("Function".to_string()),
193 Expr::Duration(_, _) => Some("Duration".to_string()),
194 Expr::Spread(_, _) => None,
195 Expr::If(_, _) => None,
196 Expr::While(_, _) => None,
197 Expr::For(_, _) => None,
198 Expr::Loop(_, _) => None,
199 Expr::Let(_, _) => None,
200 Expr::Assign(_, _) => None,
201 Expr::Break(_, _) => None,
202 Expr::Continue(_) => None,
203 Expr::Return(_, _) => None,
204 Expr::Match(match_expr, _) => {
205 let mut arm_types: Vec<String> = match_expr
206 .arms
207 .iter()
208 .filter_map(|arm| {
209 let mut arm_env = env.clone();
210 collect_typed_pattern_bindings(&arm.pattern, &mut arm_env);
211 infer_expr_type_with_env(&arm.body, &arm_env)
212 })
213 .collect();
214 if arm_types.is_empty() {
215 None
216 } else {
217 arm_types.sort();
218 arm_types.dedup();
219 match arm_types.len() {
220 0 => None,
221 1 => arm_types.into_iter().next(),
222 _ => Some(arm_types.join(" | ")),
223 }
224 }
225 }
226 Expr::Unit(_) => Some("()".to_string()),
227 Expr::Range { .. } => Some("Range".to_string()),
228 Expr::TimeframeContext { expr, .. } => infer_expr_type_with_env(expr, env),
229 Expr::ListComprehension(_, _) => Some("Array".to_string()),
230 Expr::SimulationCall { .. } => Some("SimulationResult".to_string()),
231 Expr::WindowExpr(_, _) => Some("Number".to_string()),
232 Expr::FuzzyComparison { .. } => Some("bool".to_string()),
233 Expr::FromQuery(_, _) => Some("Array".to_string()),
234 Expr::StructLiteral { type_name, .. } => Some(type_name.to_string()),
235 Expr::Await(inner, _) => infer_expr_type_with_env(inner, env),
236 Expr::Join(_, _) => Some("Array".to_string()),
237 Expr::Annotated { target, .. } => infer_expr_type_with_env(target, env),
238 Expr::AsyncLet(_, _) => None,
239 Expr::AsyncScope(inner, _) => infer_expr_type_with_env(inner, env),
240 Expr::Comptime(_, _) => None,
241 Expr::ComptimeFor(_, _) => None,
242 Expr::Reference { expr: inner, .. } => infer_expr_type_with_env(inner, env),
243 Expr::TableRows(..) => Some("Table".to_string()),
244 }
245}
246
247pub fn infer_literal_type(lit: &Literal) -> String {
249 match lit {
250 Literal::Int(_) => "int".to_string(),
251 Literal::UInt(_) => "u64".to_string(),
252 Literal::TypedInt(_, w) => w.type_name().to_string(),
253 Literal::Number(_) => "number".to_string(),
254 Literal::Decimal(_) => "decimal".to_string(),
255 Literal::String(_) => "string".to_string(),
256 Literal::FormattedString { .. } => "string".to_string(),
257 Literal::ContentString { .. } => "string".to_string(),
258 Literal::Bool(_) => "bool".to_string(),
259 Literal::Char(_) => "char".to_string(),
260 Literal::None => "Option".to_string(),
261 Literal::Unit => "()".to_string(),
262 Literal::Timeframe(_) => "Timeframe".to_string(),
263 }
264}
265
266pub fn extract_wrapper_inner(type_name: &str) -> Option<String> {
268 if type_name.starts_with("Result<") && type_name.ends_with('>') {
269 let inner = &type_name[7..type_name.len() - 1];
270 if let Some(comma_pos) = inner.find(',') {
271 return Some(inner[..comma_pos].trim().to_string());
272 }
273 return Some(inner.to_string());
274 }
275 if type_name.starts_with("Option<") && type_name.ends_with('>') {
276 let inner = &type_name[7..type_name.len() - 1];
277 return Some(inner.to_string());
278 }
279 if type_name.ends_with('?') {
280 return Some(type_name[..type_name.len() - 1].to_string());
281 }
282 Some(type_name.to_string())
283}
284
285pub fn infer_function_return_type(name: &str) -> Option<String> {
287 unified_metadata()
288 .get_function(name)
289 .map(|f| f.return_type.clone())
290}
291
292fn infer_array_type(elements: &[Expr]) -> String {
294 if elements.is_empty() {
295 return "Array".to_string();
296 }
297 if let Some(first_type) = infer_expr_type(&elements[0]) {
298 let all_same = elements
299 .iter()
300 .skip(1)
301 .all(|e| infer_expr_type(e).as_deref() == Some(first_type.as_str()));
302 if all_same {
303 format!("{}[]", first_type)
304 } else {
305 "Array".to_string()
306 }
307 } else {
308 "Array".to_string()
309 }
310}
311
312fn format_object_shape_from_type_fields(fields: &[ObjectTypeField]) -> String {
313 if fields.is_empty() {
314 return "{}".to_string();
315 }
316
317 let parts: Vec<String> = fields
318 .iter()
319 .map(|field| {
320 let field_type = type_annotation_to_string(&field.type_annotation)
321 .unwrap_or_else(|| "unknown".to_string());
322 if field.optional {
323 format!("{}?: {}", field.name, field_type)
324 } else {
325 format!("{}: {}", field.name, field_type)
326 }
327 })
328 .collect();
329 format!("{{ {} }}", parts.join(", "))
330}
331
332fn split_top_level(input: &str, delimiter: char) -> Vec<String> {
333 let mut parts = Vec::new();
334 let mut start = 0usize;
335 let mut paren_depth = 0usize;
336 let mut bracket_depth = 0usize;
337 let mut brace_depth = 0usize;
338 let mut angle_depth = 0usize;
339
340 for (idx, ch) in input.char_indices() {
341 match ch {
342 '(' => paren_depth += 1,
343 ')' => paren_depth = paren_depth.saturating_sub(1),
344 '[' => bracket_depth += 1,
345 ']' => bracket_depth = bracket_depth.saturating_sub(1),
346 '{' => brace_depth += 1,
347 '}' => brace_depth = brace_depth.saturating_sub(1),
348 '<' => angle_depth += 1,
349 '>' => angle_depth = angle_depth.saturating_sub(1),
350 _ => {}
351 }
352
353 if ch == delimiter
354 && paren_depth == 0
355 && bracket_depth == 0
356 && brace_depth == 0
357 && angle_depth == 0
358 {
359 parts.push(input[start..idx].trim().to_string());
360 start = idx + ch.len_utf8();
361 }
362 }
363
364 parts.push(input[start..].trim().to_string());
365 parts.into_iter().filter(|part| !part.is_empty()).collect()
366}
367
368pub fn is_structural_object_shape(type_name: &str) -> bool {
369 let t = type_name.trim();
370 t.starts_with('{') && t.ends_with('}')
371}
372
373fn is_generic_object_type(type_name: &str) -> bool {
374 type_name.trim().eq_ignore_ascii_case("object")
375}
376
377pub fn parse_object_shape_fields(shape: &str) -> Option<Vec<(String, String)>> {
378 let trimmed = shape.trim();
379 if !is_structural_object_shape(trimmed) {
380 return None;
381 }
382
383 let inner = trimmed
384 .strip_prefix('{')
385 .and_then(|s| s.strip_suffix('}'))?
386 .trim();
387 if inner.is_empty() {
388 return Some(Vec::new());
389 }
390
391 let mut fields = Vec::new();
392 for part in split_top_level(inner, ',') {
393 if part.starts_with("...") {
394 continue;
395 }
396 let (name, ty) = part.split_once(':')?;
397 let field_name = name.trim().trim_end_matches('?').trim().to_string();
398 let field_type = ty.trim().to_string();
399 if field_name.is_empty() || field_type.is_empty() {
400 return None;
401 }
402 fields.push((field_name, field_type));
403 }
404 Some(fields)
405}
406
407pub fn format_object_shape(fields: &[(String, String)]) -> String {
408 if fields.is_empty() {
409 return "{}".to_string();
410 }
411 let field_strs: Vec<String> = fields
412 .iter()
413 .map(|(name, ty)| format!("{}: {}", name, ty))
414 .collect();
415 format!("{{ {} }}", field_strs.join(", "))
416}
417
418pub fn merge_object_shapes(left: &str, right: &str) -> Option<String> {
419 let mut merged = parse_object_shape_fields(left)?;
420 let right_fields = parse_object_shape_fields(right)?;
421
422 for (name, ty) in right_fields {
423 if !merged.iter().any(|(existing, _)| existing == &name) {
424 merged.push((name, ty));
425 }
426 }
427
428 Some(format_object_shape(&merged))
429}
430
431fn merge_structural_intersection_shapes(parts: &[String]) -> Option<String> {
432 let mut iter = parts.iter();
433 let first = iter.next()?;
434 if !is_structural_object_shape(first) {
435 return None;
436 }
437
438 let mut merged = first.clone();
439 for part in iter {
440 if !is_structural_object_shape(part) {
441 return None;
442 }
443 merged = merge_object_shapes(&merged, part)?;
444 }
445 Some(merged)
446}
447
448fn infer_add_type(left: Option<&str>, right: Option<&str>) -> Option<String> {
449 let (Some(left), Some(right)) = (left, right) else {
450 return None;
451 };
452
453 if left == "string" || right == "string" {
454 return Some("string".to_string());
455 }
456
457 if is_structural_object_shape(left) && is_structural_object_shape(right) {
458 return merge_object_shapes(left, right);
459 }
460
461 infer_numeric_arithmetic_type(Some(left), Some(right))
462}
463
464fn infer_numeric_arithmetic_type(left: Option<&str>, right: Option<&str>) -> Option<String> {
465 let (Some(left), Some(right)) = (left, right) else {
466 return None;
467 };
468 if !is_numeric_type_name(left) || !is_numeric_type_name(right) {
469 return None;
470 }
471 if left == right {
472 return Some(left.to_string());
473 }
474 Some("number".to_string())
475}
476
477fn is_numeric_type_name(ty: &str) -> bool {
478 matches!(
479 ty,
480 "int" | "number" | "decimal" | "float" | "integer" | "f64" | "i64"
481 )
482}
483
484fn collect_typed_pattern_bindings(pattern: &Pattern, env: &mut HashMap<String, String>) {
485 match pattern {
486 Pattern::Typed {
487 name,
488 type_annotation,
489 } => {
490 if let Some(type_name) = type_annotation_to_string(type_annotation) {
491 env.insert(name.clone(), type_name);
492 }
493 }
494 Pattern::Array(patterns) => {
495 for pat in patterns {
496 collect_typed_pattern_bindings(pat, env);
497 }
498 }
499 Pattern::Object(fields) => {
500 for (_, pat) in fields {
501 collect_typed_pattern_bindings(pat, env);
502 }
503 }
504 Pattern::Constructor { fields, .. } => match fields {
505 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
506 for pat in patterns {
507 collect_typed_pattern_bindings(pat, env);
508 }
509 }
510 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
511 for (_, pat) in fields {
512 collect_typed_pattern_bindings(pat, env);
513 }
514 }
515 shape_ast::ast::PatternConstructorFields::Unit => {}
516 },
517 Pattern::Identifier(_) | Pattern::Literal(_) | Pattern::Wildcard => {}
518 }
519}
520
521pub fn infer_object_shape(entries: &[ObjectEntry]) -> String {
523 format_object_shape(&collect_object_fields(entries))
524}
525
526pub fn extract_struct_fields(
533 program: &Program,
534) -> std::collections::HashMap<String, Vec<(String, String)>> {
535 use shape_ast::ast::Statement;
536
537 let mut result = std::collections::HashMap::new();
538
539 for item in &program.items {
541 if let Item::StructType(struct_def, _) = item {
542 let fields: Vec<(String, String)> = struct_def
543 .fields
544 .iter()
545 .map(|f| {
546 let mut type_str = type_annotation_to_string(&f.type_annotation)
547 .unwrap_or_else(|| "unknown".to_string());
548 if f.is_comptime {
549 let default_repr = f
551 .default_value
552 .as_ref()
553 .map(|expr| match expr {
554 Expr::Literal(shape_ast::ast::Literal::String(s), _) => {
555 format!(" = \"{}\"", s)
556 }
557 Expr::Literal(shape_ast::ast::Literal::Number(n), _) => {
558 format!(" = {}", n)
559 }
560 Expr::Literal(shape_ast::ast::Literal::Int(n), _) => {
561 format!(" = {}", n)
562 }
563 Expr::Literal(shape_ast::ast::Literal::Bool(b), _) => {
564 format!(" = {}", b)
565 }
566 _ => String::new(),
567 })
568 .unwrap_or_default();
569 type_str = format!("comptime {}{}", type_str, default_repr);
570 }
571 (f.name.clone(), type_str)
572 })
573 .collect();
574 result.insert(struct_def.name.clone(), fields);
575 }
576 }
577
578 for item in &program.items {
580 let value_expr = match item {
581 Item::VariableDecl(decl, _) => decl.value.as_ref(),
582 Item::Statement(Statement::VariableDecl(decl, _), _) => decl.value.as_ref(),
583 _ => None,
584 };
585 if let Some(Expr::StructLiteral {
586 type_name, fields, ..
587 }) = value_expr
588 {
589 if !result.contains_key(type_name.as_str()) {
590 let inferred: Vec<(String, String)> = fields
591 .iter()
592 .map(|(name, expr)| {
593 let type_str =
594 infer_expr_type(expr).unwrap_or_else(|| "unknown".to_string());
595 (name.clone(), type_str)
596 })
597 .collect();
598 result.insert(type_name.to_string(), inferred);
599 }
600 }
601 }
602
603 result
604}
605
606fn parse_named_generic_type(type_name: &str) -> Option<(String, Vec<String>)> {
607 let trimmed = type_name.trim();
608 let start = trimmed.find('<')?;
609 let end = trimmed.rfind('>')?;
610 if end <= start {
611 return None;
612 }
613 let base = trimmed[..start].trim().to_string();
614 let inner = trimmed[start + 1..end].trim();
615 if inner.is_empty() {
616 return Some((base, Vec::new()));
617 }
618 Some((base, split_top_level(inner, ',')))
619}
620
621fn replace_type_identifier(input: &str, identifier: &str, replacement: &str) -> String {
622 if identifier.is_empty() {
623 return input.to_string();
624 }
625
626 let mut out = String::with_capacity(input.len());
627 let mut token = String::new();
628 let mut token_started = false;
629
630 let flush_token = |token: &mut String, out: &mut String| {
631 if token.is_empty() {
632 return;
633 }
634 if token == identifier {
635 out.push_str(replacement);
636 } else {
637 out.push_str(token);
638 }
639 token.clear();
640 };
641
642 for ch in input.chars() {
643 let is_ident_char = ch.is_ascii_alphanumeric() || ch == '_';
644 if is_ident_char {
645 token.push(ch);
646 token_started = true;
647 } else {
648 if token_started {
649 flush_token(&mut token, &mut out);
650 token_started = false;
651 }
652 out.push(ch);
653 }
654 }
655 if token_started {
656 flush_token(&mut token, &mut out);
657 }
658
659 out
660}
661
662fn substitute_type_params_in_field_type(
663 field_type: &str,
664 bindings: &HashMap<String, String>,
665) -> String {
666 let mut resolved = field_type.to_string();
667 for (param, arg) in bindings {
668 resolved = replace_type_identifier(&resolved, param, arg);
669 }
670 resolved
671}
672
673pub fn resolve_struct_field_type(
676 program: &Program,
677 type_name: &str,
678 field_name: &str,
679) -> Option<String> {
680 let (base_name, generic_args) = parse_named_generic_type(type_name)
681 .unwrap_or_else(|| (type_name.trim().to_string(), Vec::new()));
682
683 for item in &program.items {
684 let Item::StructType(struct_def, _) = item else {
685 continue;
686 };
687 if struct_def.name != base_name {
688 continue;
689 }
690
691 let field = struct_def.fields.iter().find(|f| f.name == field_name)?;
692 let mut field_type = type_annotation_to_string(&field.type_annotation)
693 .unwrap_or_else(|| "unknown".to_string());
694
695 if let Some(type_params) = &struct_def.type_params {
696 if !type_params.is_empty() {
697 let mut bindings: HashMap<String, String> = HashMap::new();
698 for (idx, param) in type_params.iter().enumerate() {
699 let bound = generic_args.get(idx).cloned().or_else(|| {
700 param
701 .default_type
702 .as_ref()
703 .and_then(type_annotation_to_string)
704 });
705 if let Some(bound) = bound {
706 bindings.insert(param.name.clone(), bound);
707 }
708 }
709 field_type = substitute_type_params_in_field_type(&field_type, &bindings);
710 }
711 }
712
713 return Some(field_type);
714 }
715
716 None
717}
718
719pub fn type_to_string(ty: &Type) -> String {
722 match ty {
723 Type::Concrete(annotation) => {
724 type_annotation_to_string(annotation).unwrap_or_else(|| "unknown".to_string())
725 }
726 Type::Generic { base, args } => {
727 let base_name = type_to_string(base);
728 if args.is_empty() {
729 base_name
730 } else {
731 let arg_list: Vec<String> = args.iter().map(type_to_string).collect();
732 format!("{}<{}>", base_name, arg_list.join(", "))
733 }
734 }
735 Type::Variable(_) => "unknown".to_string(),
736 Type::Constrained { .. } => "unknown".to_string(),
737 Type::Function { params, returns } => {
738 let param_list: Vec<String> = params.iter().map(type_to_string).collect();
739 format!("({}) -> {}", param_list.join(", "), type_to_string(returns))
740 }
741 }
742}
743
744pub fn infer_expr_type_via_engine(expr: &Expr) -> Option<String> {
747 let mut engine = TypeInferenceEngine::new();
748 match engine.infer_expr(expr) {
749 Ok(ty) => {
750 let s = type_to_string(&ty);
751 if s == "unknown" { None } else { Some(s) }
752 }
753 Err(_) => None,
754 }
755}
756
757#[derive(Debug, Clone)]
759pub enum ParamReferenceMode {
760 Shared,
761 Exclusive,
762}
763
764impl ParamReferenceMode {
765 pub fn prefix(&self) -> &'static str {
766 match self {
767 ParamReferenceMode::Shared => "&",
768 ParamReferenceMode::Exclusive => "&mut ",
769 }
770 }
771}
772
773#[derive(Debug, Clone)]
775pub struct FunctionTypeInfo {
776 pub param_types: Vec<(String, String)>,
779 pub param_ref_modes: HashMap<String, ParamReferenceMode>,
781 pub return_type: Option<String>,
783}
784
785pub fn infer_function_signatures(program: &Program) -> HashMap<String, FunctionTypeInfo> {
787 let augmented = shape_ast::transform::augment_program_with_generated_extends(program);
788 let mut engine = TypeInferenceEngine::new();
789 let mut result = HashMap::new();
790 let inferred_param_pass_modes = shape_vm::compiler::infer_param_pass_modes(&augmented);
791
792 let func_defs: Vec<&shape_ast::ast::FunctionDef> = program
794 .items
795 .iter()
796 .filter_map(|item| {
797 if let Item::Function(f, _) = item {
798 Some(f)
799 } else {
800 None
801 }
802 })
803 .collect();
804
805 let (types, _) = engine.infer_program_best_effort(&augmented);
806 let func_map: HashMap<&str, &&shape_ast::ast::FunctionDef> =
807 func_defs.iter().map(|f| (f.name.as_str(), f)).collect();
808 let mut inferred_infos: HashMap<String, FunctionTypeInfo> = HashMap::new();
809
810 for (name, ty) in &types {
811 let Some(func_def) = func_map.get(name.as_str()) else {
812 continue;
813 };
814
815 let (param_type_strings, return_type_string) = match ty {
816 Type::Function { params, returns } => (
817 params.iter().map(type_to_string).collect::<Vec<_>>(),
818 Some(type_to_string(returns)),
819 ),
820 Type::Concrete(TypeAnnotation::Function { params, returns }) => (
821 params
822 .iter()
823 .map(|p| {
824 type_annotation_to_string(&p.type_annotation)
825 .unwrap_or_else(|| "unknown".to_string())
826 })
827 .collect::<Vec<_>>(),
828 type_annotation_to_string(returns),
829 ),
830 _ => continue,
831 };
832
833 let param_types: Vec<(String, String)> = func_def
834 .params
835 .iter()
836 .zip(param_type_strings.iter())
837 .filter_map(|(ast_param, inferred_type)| {
838 if ast_param.type_annotation.is_some() {
839 return None;
840 }
841 let param_name = ast_param.simple_name()?.to_string();
842 if inferred_type == "_" || inferred_type == "unknown" {
843 return None;
844 }
845 Some((param_name, inferred_type.clone()))
846 })
847 .collect();
848 let mut param_ref_modes = HashMap::new();
849 let param_modes = inferred_param_pass_modes
850 .get(name)
851 .cloned()
852 .unwrap_or_default();
853 for (idx, ast_param) in func_def.params.iter().enumerate() {
854 let Some(param_name) = ast_param.simple_name() else {
855 continue;
856 };
857 let mode = match param_modes
858 .get(idx)
859 .copied()
860 .unwrap_or(if ast_param.is_reference {
861 ParamPassMode::ByRefShared
862 } else {
863 ParamPassMode::ByValue
864 }) {
865 ParamPassMode::ByRefExclusive => ParamReferenceMode::Exclusive,
866 ParamPassMode::ByRefShared => ParamReferenceMode::Shared,
867 ParamPassMode::ByValue => continue,
868 };
869 param_ref_modes.insert(param_name.to_string(), mode);
870 }
871
872 let return_type = if func_def.return_type.is_none() {
873 return_type_string.filter(|s| s != "_" && s != "unknown")
874 } else {
875 None
876 };
877
878 inferred_infos.insert(
879 name.clone(),
880 FunctionTypeInfo {
881 param_types,
882 param_ref_modes,
883 return_type,
884 },
885 );
886 }
887
888 for func_def in &func_defs {
889 let mut info = inferred_infos
890 .remove(&func_def.name)
891 .unwrap_or(FunctionTypeInfo {
892 param_types: Vec::new(),
893 param_ref_modes: HashMap::new(),
894 return_type: None,
895 });
896
897 if func_def.return_type.is_none() && info.return_type.is_none() {
898 info.return_type = infer_function_return_from_body_via_engine(func_def);
899 }
900
901 if func_def.return_type.is_some() && info.param_types.is_empty() {
903 continue;
904 }
905
906 if func_def.return_type.is_none()
909 || !info.param_types.is_empty()
910 || !info.param_ref_modes.is_empty()
911 || info.return_type.is_some()
912 {
913 result.insert(func_def.name.clone(), info);
914 }
915 }
916
917 for item in &program.items {
921 if let Item::ForeignFunction(foreign_fn, _) = item {
922 let ret = foreign_fn
923 .return_type
924 .as_ref()
925 .and_then(type_annotation_to_string);
926 result
927 .entry(foreign_fn.name.clone())
928 .or_insert_with(|| FunctionTypeInfo {
929 param_types: Vec::new(),
930 param_ref_modes: HashMap::new(),
931 return_type: ret,
932 });
933 }
934 }
935
936 result
937}
938
939fn infer_function_return_from_body_via_engine(
940 func_def: &shape_ast::ast::FunctionDef,
941) -> Option<String> {
942 infer_return_type_for_block_with_params(&func_def.body, Some(&func_def.params))
943}
944
945pub fn infer_block_return_type_via_engine(body: &[Statement]) -> Option<String> {
949 infer_return_type_for_block_with_params(body, None)
950}
951
952fn infer_return_type_for_block_with_params(
953 body: &[Statement],
954 params: Option<&[shape_ast::ast::FunctionParameter]>,
955) -> Option<String> {
956 let return_exprs = collect_return_expressions(body);
957 if return_exprs.is_empty() {
958 return None;
959 }
960
961 let mut engine = TypeInferenceEngine::new();
962
963 if let Some(params) = params {
964 for param in params {
965 let Some(name) = param.simple_name() else {
966 continue;
967 };
968 let Some(type_ann) = ¶m.type_annotation else {
969 continue;
970 };
971 engine
972 .env
973 .define(name, TypeScheme::mono(Type::Concrete(type_ann.clone())));
974 }
975 }
976
977 let mut inferred = Vec::new();
978 for expr in return_exprs {
979 if let Ok(ty) = engine.infer_expr(&expr) {
980 let s = type_to_string(&ty);
981 if s != "unknown" {
982 inferred.push(s);
983 continue;
984 }
985 }
986
987 if let Some(fallback) = infer_expr_type(&expr) {
990 if fallback != "unknown" {
991 inferred.push(fallback);
992 }
993 }
994 }
995
996 inferred.sort();
997 inferred.dedup();
998 match inferred.len() {
999 0 => None,
1000 1 => inferred.into_iter().next(),
1001 _ => Some(inferred.join(" | ")),
1002 }
1003}
1004
1005fn collect_return_expressions(body: &[Statement]) -> Vec<Expr> {
1006 let mut exprs = Vec::new();
1007
1008 for stmt in body {
1009 match stmt {
1010 Statement::Return(Some(expr), _) => exprs.push(expr.clone()),
1011 Statement::Expression(expr, _) => collect_return_exprs_from_expr(expr, &mut exprs),
1012 _ => {}
1013 }
1014 }
1015
1016 if let Some(Statement::Expression(expr, _)) = body.last() {
1017 if !matches!(expr, Expr::Return(_, _)) {
1018 exprs.push(expr.clone());
1019 }
1020 }
1021
1022 exprs
1023}
1024
1025fn collect_return_exprs_from_expr(expr: &Expr, out: &mut Vec<Expr>) {
1026 match expr {
1027 Expr::Return(Some(inner), _) => out.push(inner.as_ref().clone()),
1028 Expr::If(if_expr, _) => {
1029 collect_return_exprs_from_expr(&if_expr.then_branch, out);
1030 if let Some(else_branch) = &if_expr.else_branch {
1031 collect_return_exprs_from_expr(else_branch, out);
1032 }
1033 }
1034 Expr::Block(block_expr, _) => {
1035 for item in &block_expr.items {
1036 match item {
1037 shape_ast::ast::BlockItem::Statement(Statement::Expression(inner, _)) => {
1038 collect_return_exprs_from_expr(inner, out)
1039 }
1040 shape_ast::ast::BlockItem::Expression(inner) => {
1041 collect_return_exprs_from_expr(inner, out)
1042 }
1043 _ => {}
1044 }
1045 }
1046 }
1047 _ => {}
1048 }
1049}
1050
1051pub fn infer_program_types(program: &Program) -> HashMap<String, String> {
1053 infer_program_types_with_context(program, None, None, None)
1054}
1055
1056pub fn infer_program_types_with_context(
1058 program: &Program,
1059 current_file: Option<&Path>,
1060 workspace_root: Option<&Path>,
1061 current_source: Option<&str>,
1062) -> HashMap<String, String> {
1063 let augmented = shape_ast::transform::augment_program_with_generated_extends(program);
1064 let mut engine = TypeInferenceEngine::new();
1065 let mut types = HashMap::new();
1066
1067 let (inferred, _) = engine.infer_program_best_effort(&augmented);
1068 for (name, ty) in inferred {
1069 let mut s = type_to_string(&ty);
1070 if let Some(structural) = infer_variable_type(&augmented, &name) {
1071 if is_structural_object_shape(&structural) {
1072 if is_structural_object_shape(&s) {
1073 if let Some(merged) = merge_object_shapes(&s, &structural) {
1074 s = merged;
1075 }
1076 } else if is_generic_object_type(&s) {
1077 s = structural;
1078 }
1079 }
1080 }
1081 if s != "unknown" {
1082 types.insert(name, s);
1083 }
1084 }
1085
1086 augment_schema_backed_module_call_types(
1087 program,
1088 &mut types,
1089 current_file,
1090 workspace_root,
1091 current_source,
1092 );
1093
1094 types
1095}
1096
1097fn augment_schema_backed_module_call_types(
1098 program: &Program,
1099 types: &mut HashMap<String, String>,
1100 current_file: Option<&Path>,
1101 workspace_root: Option<&Path>,
1102 current_source: Option<&str>,
1103) {
1104 for item in &program.items {
1105 match item {
1106 Item::VariableDecl(var_decl, _) => {
1107 maybe_insert_schema_backed_type_from_decl(
1108 var_decl,
1109 types,
1110 current_file,
1111 workspace_root,
1112 current_source,
1113 );
1114 }
1115 Item::Statement(Statement::VariableDecl(var_decl, _), _) => {
1116 maybe_insert_schema_backed_type_from_decl(
1117 var_decl,
1118 types,
1119 current_file,
1120 workspace_root,
1121 current_source,
1122 );
1123 }
1124 _ => {}
1125 }
1126 }
1127}
1128
1129fn maybe_insert_schema_backed_type_from_decl(
1130 var_decl: &VariableDecl,
1131 types: &mut HashMap<String, String>,
1132 current_file: Option<&Path>,
1133 workspace_root: Option<&Path>,
1134 current_source: Option<&str>,
1135) {
1136 let Some(name) = var_decl.pattern.as_identifier() else {
1137 return;
1138 };
1139 let Some(value) = &var_decl.value else {
1140 return;
1141 };
1142 let Some(conn_type) =
1143 infer_schema_backed_type_from_expr(value, current_file, workspace_root, current_source)
1144 else {
1145 return;
1146 };
1147 types.insert(name.to_string(), conn_type);
1148}
1149
1150fn infer_schema_backed_type_from_expr(
1151 expr: &Expr,
1152 current_file: Option<&Path>,
1153 workspace_root: Option<&Path>,
1154 current_source: Option<&str>,
1155) -> Option<String> {
1156 let Expr::MethodCall {
1157 receiver,
1158 method,
1159 args,
1160 named_args: _,
1161 ..
1162 } = expr
1163 else {
1164 return None;
1165 };
1166 let module_name = match receiver.as_ref() {
1167 Expr::Identifier(name, _) => name.as_str(),
1168 _ => return None,
1169 };
1170 let source_schema_provider = schema_provider_for_module_call(
1171 module_name,
1172 method,
1173 args.len(),
1174 current_file,
1175 workspace_root,
1176 current_source,
1177 )?;
1178 let uri = match args.first() {
1179 Some(Expr::Literal(Literal::String(uri), _)) => Some(uri.as_str()),
1180 _ => None,
1181 }?;
1182 let source = resolve_source_schema_for_module_call(
1183 module_name,
1184 &source_schema_provider,
1185 uri,
1186 current_file,
1187 workspace_root,
1188 current_source,
1189 )?;
1190 Some(connection_shape_from_source_schema(&source))
1191}
1192
1193fn schema_provider_for_module_call(
1194 module_name: &str,
1195 function_name: &str,
1196 arg_count: usize,
1197 current_file: Option<&Path>,
1198 workspace_root: Option<&Path>,
1199 current_source: Option<&str>,
1200) -> Option<String> {
1201 let schema = crate::completion::imports::extension_module_schema_with_context(
1202 module_name,
1203 current_file,
1204 workspace_root,
1205 current_source,
1206 );
1207
1208 let Some(schema) = schema else {
1209 return (arg_count == 1).then(|| "source_schema".to_string());
1213 };
1214
1215 let export = schema.functions.iter().find(|f| f.name == function_name)?;
1216 if !is_schema_backed_connection_return(export.return_type.as_deref()) {
1217 return None;
1218 }
1219
1220 schema
1221 .functions
1222 .iter()
1223 .find(|f| f.name == "source_schema")
1224 .map(|f| f.name.clone())
1225}
1226
/// Reports whether an extension function's declared return type names a connection type,
/// i.e. ends with `Connection` (which also covers `DbConnection`).
fn is_schema_backed_connection_return(return_type: Option<&str>) -> bool {
    matches!(return_type, Some(ty) if ty.ends_with("Connection"))
}
1233
1234fn resolve_source_schema_for_module_call(
1235 module_name: &str,
1236 source_schema_provider: &str,
1237 uri: &str,
1238 current_file: Option<&Path>,
1239 workspace_root: Option<&Path>,
1240 current_source: Option<&str>,
1241) -> Option<SourceSchema> {
1242 let lock_path = lock_path_for_context(current_file, workspace_root);
1243 if let Ok((source, _diagnostics)) = load_cached_source_for_uri_with_diagnostics(&lock_path, uri)
1244 {
1245 return Some(source);
1246 }
1247
1248 let source = crate::completion::imports::extension_source_schema_via_with_context(
1249 module_name,
1250 source_schema_provider,
1251 uri,
1252 current_file,
1253 workspace_root,
1254 current_source,
1255 )?;
1256
1257 let mut cache = DataSourceSchemaCache::load_or_empty(&lock_path);
1258 cache.upsert_source(source.clone());
1259 let _ = cache.save(&lock_path);
1260
1261 Some(source)
1262}
1263
1264fn lock_path_for_context(current_file: Option<&Path>, workspace_root: Option<&Path>) -> PathBuf {
1265 if let Some(path) = current_file {
1266 if let Some(parent) = path.parent()
1267 && let Some(project) = shape_runtime::project::find_project_root(parent)
1268 {
1269 return project.root_path.join("shape.lock");
1270 }
1271 return path.with_extension("lock");
1272 }
1273
1274 if let Some(root) = workspace_root
1275 && let Some(project) = shape_runtime::project::find_project_root(root)
1276 {
1277 return project.root_path.join("shape.lock");
1278 }
1279
1280 default_cache_path()
1281}
1282
1283fn connection_shape_from_source_schema(source: &SourceSchema) -> String {
1284 let mut tables = source.tables.values().collect::<Vec<_>>();
1285 tables.sort_by(|left, right| left.name.cmp(&right.name));
1286
1287 let fields = tables
1288 .into_iter()
1289 .filter_map(|table| {
1290 if !is_valid_shape_identifier(&table.name) {
1291 return None;
1292 }
1293 Some(format!(
1294 "{}: Table<{}>",
1295 table.name,
1296 row_shape_from_entity_schema(table)
1297 ))
1298 })
1299 .collect::<Vec<_>>();
1300
1301 if fields.is_empty() {
1302 "{}".to_string()
1303 } else {
1304 format!("{{ {} }}", fields.join(", "))
1305 }
1306}
1307
1308fn row_shape_from_entity_schema(entity: &EntitySchema) -> String {
1309 let fields = entity
1310 .columns
1311 .iter()
1312 .filter_map(|column| {
1313 if !is_valid_shape_identifier(&column.name) {
1314 return None;
1315 }
1316 Some(format!(
1317 "{}: {}",
1318 column.name,
1319 schema_column_type(&column.shape_type, column.nullable)
1320 ))
1321 })
1322 .collect::<Vec<_>>();
1323
1324 if fields.is_empty() {
1325 "{}".to_string()
1326 } else {
1327 format!("{{ {} }}", fields.join(", "))
1328 }
1329}
1330
/// Maps a cached column type to a Shape type string.
///
/// Known scalar names pass through unchanged and anything else becomes the wildcard `_`.
/// Nullable columns are wrapped in `Option<...>`.
fn schema_column_type(shape_type: &str, nullable: bool) -> String {
    const KNOWN: [&str; 6] = ["int", "number", "decimal", "string", "bool", "timestamp"];

    let base = if KNOWN.contains(&shape_type) {
        shape_type
    } else {
        "_"
    };

    if nullable {
        format!("Option<{base}>")
    } else {
        base.to_owned()
    }
}
1347
/// Reports whether `name` is a plain ASCII identifier: a letter or `_` followed by any number
/// of letters, digits, or `_`. Empty strings are rejected.
fn is_valid_shape_identifier(name: &str) -> bool {
    let starts_ident = |c: char| c == '_' || c.is_ascii_alphabetic();
    let continues_ident = |c: char| c == '_' || c.is_ascii_alphanumeric();

    match name.chars().next() {
        Some(head) if starts_ident(head) => name.chars().skip(1).all(continues_ident),
        _ => false,
    }
}
1358
1359pub fn infer_variable_type(program: &Program, var_name: &str) -> Option<String> {
1360 let mut finder = VariableFinder {
1361 target_name: var_name,
1362 found_type: None,
1363 found_expr: None,
1364 };
1365 walk_program(&mut finder, program);
1366
1367 if let Some(Expr::Object(entries, _)) = &finder.found_expr {
1368 let mut fields = collect_object_fields(entries);
1369
1370 let assignments = PropertyAssignmentCollector::collect(program);
1371 for assignment in &assignments {
1372 if assignment.variable == var_name
1373 && !fields
1374 .iter()
1375 .any(|(field_name, _)| field_name == &assignment.property)
1376 {
1377 let prop_type = infer_expr_type_via_engine(&assignment.value_expr)
1378 .unwrap_or_else(|| "unknown".to_string());
1379 fields.push((assignment.property.clone(), prop_type));
1380 }
1381 }
1382
1383 return Some(format_object_shape(&fields));
1384 }
1385
1386 finder.found_type
1387}
1388
1389pub fn infer_variable_type_for_display(
1395 program: &Program,
1396 var_name: &str,
1397 offset: usize,
1398) -> Option<String> {
1399 let (visible_fields, masked_fields) =
1400 infer_object_field_state_at_offset(program, var_name, offset)?;
1401 Some(format_object_shape_with_masked_fields(
1402 &visible_fields,
1403 &masked_fields,
1404 ))
1405}
1406
1407pub fn infer_variable_visible_type_at_offset(
1411 program: &Program,
1412 var_name: &str,
1413 offset: usize,
1414) -> Option<String> {
1415 let (visible_fields, _) = infer_object_field_state_at_offset(program, var_name, offset)?;
1416 Some(format_object_shape(&visible_fields))
1417}
1418
1419fn infer_object_field_state_at_offset(
1420 program: &Program,
1421 var_name: &str,
1422 offset: usize,
1423) -> Option<(Vec<(String, String)>, Vec<(String, String)>)> {
1424 let mut finder = VariableFinder {
1425 target_name: var_name,
1426 found_type: None,
1427 found_expr: None,
1428 };
1429 walk_program(&mut finder, program);
1430
1431 let Expr::Object(entries, _) = finder.found_expr.as_ref()? else {
1432 return None;
1433 };
1434
1435 let mut visible_fields = collect_object_fields(entries);
1436 let mut visible_names: HashSet<String> = visible_fields
1437 .iter()
1438 .map(|(name, _)| name.clone())
1439 .collect();
1440
1441 let assignments = PropertyAssignmentCollector::collect(program);
1442 let mut hoisted: Vec<(String, usize, String)> = Vec::new();
1443
1444 for assignment in assignments.iter().filter(|a| a.variable == var_name) {
1445 if visible_names.contains(&assignment.property) {
1446 continue;
1447 }
1448 if hoisted
1449 .iter()
1450 .any(|(existing, _, _)| existing == &assignment.property)
1451 {
1452 continue;
1453 }
1454
1455 let prop_type = infer_expr_type_via_engine(&assignment.value_expr)
1456 .unwrap_or_else(|| "unknown".to_string());
1457 hoisted.push((
1458 assignment.property.clone(),
1459 assignment.assignment_span.start,
1460 prop_type,
1461 ));
1462 }
1463
1464 hoisted.sort_by_key(|(_, assignment_offset, _)| *assignment_offset);
1465
1466 let mut masked_fields = Vec::new();
1467 for (name, assignment_offset, ty) in hoisted {
1468 if assignment_offset <= offset {
1469 visible_names.insert(name.clone());
1470 visible_fields.push((name, ty));
1471 } else {
1472 masked_fields.push((name, ty));
1473 }
1474 }
1475
1476 Some((visible_fields, masked_fields))
1477}
1478
1479fn format_object_shape_with_masked_fields(
1480 visible_fields: &[(String, String)],
1481 masked_fields: &[(String, String)],
1482) -> String {
1483 if masked_fields.is_empty() {
1484 return format_object_shape(visible_fields);
1485 }
1486
1487 let visible = visible_fields
1488 .iter()
1489 .map(|(name, ty)| format!("{}: {}", name, ty))
1490 .collect::<Vec<_>>()
1491 .join(", ");
1492 let masked = masked_fields
1493 .iter()
1494 .map(|(name, ty)| format!("{}: {}", name, ty))
1495 .collect::<Vec<_>>()
1496 .join(", ");
1497
1498 if visible.is_empty() {
1499 format!("{{ /* {} */ }}", masked)
1500 } else {
1501 format!("{{ {} /*, {} */ }}", visible, masked)
1502 }
1503}
1504
1505fn collect_object_fields(entries: &[ObjectEntry]) -> Vec<(String, String)> {
1506 let mut fields = Vec::new();
1507 for entry in entries {
1508 if let ObjectEntry::Field {
1509 key,
1510 value,
1511 type_annotation,
1512 } = entry
1513 {
1514 let field_type = if let Some(type_ann) = type_annotation {
1515 type_annotation_to_string(type_ann).unwrap_or_else(|| "unknown".to_string())
1516 } else {
1517 infer_expr_type_via_engine(value).unwrap_or_else(|| "unknown".to_string())
1518 };
1519 fields.push((key.clone(), field_type));
1520 }
1521 }
1522 fields
1523}
1524
1525struct VariableFinder<'a> {
1526 target_name: &'a str,
1527 found_type: Option<String>,
1528 found_expr: Option<Expr>,
1529}
1530
1531impl<'a> Visitor for VariableFinder<'a> {
1532 fn visit_item(&mut self, item: &Item) -> bool {
1533 if let Item::VariableDecl(decl, _) = item {
1534 self.check_variable_decl(decl);
1535 }
1536 true
1537 }
1538
1539 fn visit_stmt(&mut self, stmt: &Statement) -> bool {
1540 if let Statement::VariableDecl(decl, _) = stmt {
1541 self.check_variable_decl(decl);
1542 }
1543 true
1544 }
1545}
1546
1547impl<'a> VariableFinder<'a> {
1548 fn check_variable_decl(&mut self, decl: &VariableDecl) {
1549 if let Some(name) = decl.pattern.as_identifier() {
1550 if name == self.target_name {
1551 if let Some(value) = &decl.value {
1552 self.found_expr = Some(value.clone());
1553 }
1554
1555 if let Some(type_ann) = &decl.type_annotation {
1556 self.found_type = type_annotation_to_string(type_ann);
1557 } else if let Some(value) = &decl.value {
1558 self.found_type = infer_expr_type_via_engine(value);
1559 }
1560 }
1561 }
1562 }
1563}
1564
/// A method offered for completion on a type, as collected by `extract_type_methods`.
#[derive(Debug, Clone)]
pub struct MethodCompletionInfo {
    /// Method name.
    pub name: String,
    /// Rendered signature: `name(p1, p2): Ret` for trait-declared methods, or `name(p1, p2)`
    /// for methods taken from `impl`/`extend` bodies.
    pub signature: Option<String>,
    /// Trait the method is associated with (declared by it, or implemented in an
    /// `impl Trait for T` block); `None` for methods added via `extend`.
    pub from_trait: Option<String>,
    /// Doc comment text (body preferred, else summary), if any.
    pub documentation: Option<String>,
}
1573
1574pub fn extract_type_methods(program: &Program) -> HashMap<String, Vec<MethodCompletionInfo>> {
1580 let augmented = shape_ast::transform::augment_program_with_generated_extends(program);
1581 let mut result: HashMap<String, Vec<MethodCompletionInfo>> = HashMap::new();
1582
1583 let mut trait_methods: HashMap<String, Vec<MethodCompletionInfo>> = HashMap::new();
1585 for item in &augmented.items {
1586 if let Item::Trait(trait_def, _) = item {
1587 let methods: Vec<MethodCompletionInfo> = trait_def
1588 .members
1589 .iter()
1590 .filter_map(|member| match member {
1591 TraitMember::Required(
1592 im @ InterfaceMember::Method {
1593 name,
1594 params,
1595 return_type,
1596 ..
1597 },
1598 ) => {
1599 let param_names: Vec<String> = params
1600 .iter()
1601 .map(|p| p.name.clone().unwrap_or_else(|| "_".to_string()))
1602 .collect();
1603 let sig = format!(
1604 "{}({}): {}",
1605 name,
1606 param_names.join(", "),
1607 type_annotation_to_string(return_type)
1608 .unwrap_or_else(|| "_".to_string())
1609 );
1610 Some(MethodCompletionInfo {
1611 name: name.clone(),
1612 signature: Some(sig),
1613 from_trait: Some(trait_def.name.clone()),
1614 documentation: interface_member_doc(im),
1615 })
1616 }
1617 _ => None,
1618 })
1619 .collect();
1620 trait_methods.insert(trait_def.name.clone(), methods);
1621 }
1622 }
1623
1624 for item in &augmented.items {
1626 match item {
1627 Item::Impl(impl_block, _) => {
1628 let target_type = match &impl_block.target_type {
1629 shape_ast::ast::TypeName::Simple(name) => name.to_string(),
1630 shape_ast::ast::TypeName::Generic { name, .. } => name.to_string(),
1631 };
1632 let trait_name = match &impl_block.trait_name {
1633 shape_ast::ast::TypeName::Simple(name) => name.to_string(),
1634 shape_ast::ast::TypeName::Generic { name, .. } => name.to_string(),
1635 };
1636
1637 if let Some(trait_meths) = trait_methods.get(&trait_name) {
1639 let entry = result.entry(target_type.clone()).or_default();
1640 for m in trait_meths {
1641 if !entry.iter().any(|existing| existing.name == m.name) {
1643 entry.push(m.clone());
1644 }
1645 }
1646 }
1647
1648 let entry = result.entry(target_type).or_default();
1651 for method in &impl_block.methods {
1652 if !entry.iter().any(|existing| existing.name == method.name) {
1653 let sig = format!(
1654 "{}({})",
1655 method.name,
1656 method
1657 .params
1658 .iter()
1659 .map(|p| p.simple_name().unwrap_or("_").to_string())
1660 .collect::<Vec<_>>()
1661 .join(", ")
1662 );
1663 entry.push(MethodCompletionInfo {
1664 name: method.name.clone(),
1665 signature: Some(sig),
1666 from_trait: Some(trait_name.clone()),
1667 documentation: method_doc(method.doc_comment.as_ref()),
1668 });
1669 }
1670 }
1671 }
1672 Item::Extend(extend, _) => {
1673 let type_name = match &extend.type_name {
1674 shape_ast::ast::TypeName::Simple(name) => name.to_string(),
1675 shape_ast::ast::TypeName::Generic { name, .. } => name.to_string(),
1676 };
1677 let entry = result.entry(type_name).or_default();
1678 for method in &extend.methods {
1679 if !entry.iter().any(|existing| existing.name == method.name) {
1680 let sig = format!(
1681 "{}({})",
1682 method.name,
1683 method
1684 .params
1685 .iter()
1686 .map(|p| p.simple_name().unwrap_or("_").to_string())
1687 .collect::<Vec<_>>()
1688 .join(", ")
1689 );
1690 entry.push(MethodCompletionInfo {
1691 name: method.name.clone(),
1692 signature: Some(sig),
1693 from_trait: None,
1694 documentation: method_doc(method.doc_comment.as_ref()),
1695 });
1696 }
1697 }
1698 }
1699 _ => {}
1700 }
1701 }
1702
1703 result
1704}
1705
1706fn interface_member_doc(member: &InterfaceMember) -> Option<String> {
1707 match member {
1708 InterfaceMember::Method { doc_comment, .. }
1709 | InterfaceMember::Property { doc_comment, .. }
1710 | InterfaceMember::IndexSignature { doc_comment, .. } => method_doc(doc_comment.as_ref()),
1711 }
1712}
1713
1714fn method_doc(doc_comment: Option<&shape_ast::ast::DocComment>) -> Option<String> {
1715 let comment = doc_comment?;
1716 if !comment.body.is_empty() {
1717 Some(comment.body.clone())
1718 } else if !comment.summary.is_empty() {
1719 Some(comment.summary.clone())
1720 } else {
1721 None
1722 }
1723}
1724
/// Shortens `Result<Ok, Err>` to `Result<Ok>` for display; any other input is returned as-is.
///
/// The split happens at the first comma that is not nested inside `<>`, `()`, `[]` or `{}`,
/// so tuple and object-shape ok types such as `Result<{ a: int, b: int }, E>` stay intact.
/// The `>` of a `=>`/`->` arrow is not treated as a closing bracket. If the leading
/// `Result<` is closed before the end of the string (e.g. the union
/// `Result<A, B> | Result<C, D>`), or the brackets are unbalanced, the input is returned
/// unchanged rather than truncated.
pub fn simplify_result_type(ty: &str) -> String {
    let Some(inner) = ty.strip_prefix("Result<").and_then(|s| s.strip_suffix('>')) else {
        return ty.to_string();
    };

    let mut depth: i32 = 0;
    let mut split_at: Option<usize> = None;
    let mut prev: Option<char> = None;
    for (i, ch) in inner.char_indices() {
        match ch {
            '<' | '(' | '[' | '{' => depth += 1,
            // Arrow in a function type, not a generic closer.
            '>' if matches!(prev, Some('=') | Some('-')) => {}
            '>' | ')' | ']' | '}' => {
                depth -= 1;
                if depth < 0 {
                    // The outer `Result<` closed early: `ty` is not a single Result type.
                    return ty.to_string();
                }
            }
            ',' if depth == 0 && split_at.is_none() => split_at = Some(i),
            _ => {}
        }
        prev = Some(ch);
    }

    match split_at {
        Some(i) if depth == 0 => format!("Result<{}>", inner[..i].trim()),
        _ => ty.to_string(),
    }
}
1746
1747#[cfg(test)]
1748mod tests {
1749 use super::*;
1750 use shape_ast::parser::parse_program;
1751
1752 #[test]
1753 fn test_extract_struct_fields_from_literal_no_type_def() {
1754 let code =
1756 "let b: MyType = MyType { i: 10.2D }\nmeta MyType {\n format: |v| v.i.toString()\n}\n";
1757 let program = parse_program(code).unwrap();
1758 let fields = extract_struct_fields(&program);
1759 let my_type = fields
1760 .get("MyType")
1761 .expect("Should find MyType from struct literal");
1762 assert_eq!(my_type[0], ("i".to_string(), "decimal".to_string()));
1763 }
1764
1765 #[test]
1766 fn test_extract_struct_fields_type_def_takes_precedence() {
1767 let code = "type MyType { i: int }\nlet b = MyType { i: 10.2D }\n";
1769 let program = parse_program(code).unwrap();
1770 let fields = extract_struct_fields(&program);
1771 let my_type = fields.get("MyType").expect("Should find MyType");
1772 assert_eq!(my_type[0], ("i".to_string(), "int".to_string()));
1774 }
1775
1776 #[test]
1777 fn test_infer_literal_type_formatted_string() {
1778 let ty = infer_literal_type(&Literal::FormattedString {
1779 value: "x={x}".to_string(),
1780 mode: shape_ast::ast::InterpolationMode::Braces,
1781 });
1782 assert_eq!(ty, "string");
1783 }
1784
1785 #[test]
1786 fn test_infer_program_types_basic() {
1787 let code = "let x = 42\nlet s = \"hello\"\nlet b = true";
1788 let program = parse_program(code).unwrap();
1789 let types = infer_program_types(&program);
1790 assert_eq!(types.get("x").map(|s| s.as_str()), Some("int"));
1791 assert_eq!(types.get("s").map(|s| s.as_str()), Some("string"));
1792 assert_eq!(types.get("b").map(|s| s.as_str()), Some("bool"));
1793 }
1794
1795 #[test]
1796 fn test_infer_program_types_includes_hoisted_object_fields() {
1797 let code = "let a = { x: 1 }\na.y = 2\n";
1798 let program = parse_program(code).unwrap();
1799 let types = infer_program_types(&program);
1800 let a_type = types.get("a").expect("a should have inferred type");
1801 assert!(
1802 a_type.contains("x: int") && a_type.contains("y: int"),
1803 "expected hoisted field in object type, got {}",
1804 a_type
1805 );
1806 }
1807
1808 #[test]
1809 fn test_infer_program_types_connection_uses_cached_schema_tables() {
1810 use shape_runtime::schema_cache::{
1811 DataSourceSchemaCache, EntitySchema, FieldSchema, SourceSchema, set_default_cache_path,
1812 };
1813 use std::collections::HashMap;
1814
1815 struct CachePathReset;
1816 impl Drop for CachePathReset {
1817 fn drop(&mut self) {
1818 set_default_cache_path(None);
1819 }
1820 }
1821
1822 let tmp = tempfile::tempdir().unwrap();
1823 let cache_path = tmp.path().join("shape.lock");
1824
1825 let mut cache = DataSourceSchemaCache::new();
1826 cache.upsert_source(SourceSchema {
1827 uri: "duckdb://analytics.db".to_string(),
1828 cached_at: "2026-02-17T00:00:00Z".to_string(),
1829 tables: HashMap::from([(
1830 "candles".to_string(),
1831 EntitySchema {
1832 name: "candles".to_string(),
1833 columns: vec![
1834 FieldSchema {
1835 name: "open".to_string(),
1836 shape_type: "number".to_string(),
1837 nullable: false,
1838 },
1839 FieldSchema {
1840 name: "volume".to_string(),
1841 shape_type: "int".to_string(),
1842 nullable: true,
1843 },
1844 ],
1845 },
1846 )]),
1847 });
1848 cache.save(&cache_path).unwrap();
1849
1850 set_default_cache_path(Some(cache_path));
1851 let _reset = CachePathReset;
1852
1853 let program =
1854 parse_program(r#"let conn = duckdb.connect("duckdb://analytics.db")"#).unwrap();
1855 let types = infer_program_types(&program);
1856 let conn_type = types.get("conn").expect("conn type should be inferred");
1857
1858 assert!(
1859 conn_type.contains("candles: Table<{ open: number"),
1860 "expected candles table in connection shape, got {}",
1861 conn_type
1862 );
1863 assert!(
1864 conn_type.contains("volume: Option<int>"),
1865 "expected nullable column mapped to Option<int>, got {}",
1866 conn_type
1867 );
1868 }
1869
1870 #[test]
1871 fn test_lock_path_for_context_prefers_script_lock_for_standalone_files() {
1872 let tmp = tempfile::tempdir().unwrap();
1873 let script_path = tmp.path().join("demo.shape");
1874 let expected = tmp.path().join("demo.lock");
1875 let actual = lock_path_for_context(Some(&script_path), None);
1876 assert_eq!(actual, expected);
1877 }
1878
1879 #[test]
1880 fn test_infer_program_types_with_context_uses_script_lock() {
1881 use shape_runtime::schema_cache::{
1882 DataSourceSchemaCache, EntitySchema, FieldSchema, SourceSchema,
1883 };
1884 use std::collections::HashMap;
1885
1886 let tmp = tempfile::tempdir().unwrap();
1887 let script_path = tmp.path().join("demo.shape");
1888 let lock_path = tmp.path().join("demo.lock");
1889
1890 let mut cache = DataSourceSchemaCache::new();
1891 cache.upsert_source(SourceSchema {
1892 uri: "duckdb://analytics.db".to_string(),
1893 cached_at: "2026-02-18T00:00:00Z".to_string(),
1894 tables: HashMap::from([(
1895 "candles".to_string(),
1896 EntitySchema {
1897 name: "candles".to_string(),
1898 columns: vec![FieldSchema {
1899 name: "open".to_string(),
1900 shape_type: "number".to_string(),
1901 nullable: false,
1902 }],
1903 },
1904 )]),
1905 });
1906 cache.save(&lock_path).unwrap();
1907
1908 let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
1909 let program = parse_program(source).unwrap();
1910 let types =
1911 infer_program_types_with_context(&program, Some(&script_path), None, Some(source));
1912 let conn_type = types.get("conn").expect("conn type should be inferred");
1913 assert!(
1914 conn_type.contains("candles: Table<{ open: number }>"),
1915 "expected candles table inferred from script lock, got {}",
1916 conn_type
1917 );
1918 }
1919
1920 #[test]
1921 fn test_infer_expr_type_via_engine_match() {
1922 let code = "match 1 { 1 => true, 2 => false }";
1923 let program = parse_program(code).unwrap();
1924 if let Some(shape_ast::ast::Item::Statement(
1925 shape_ast::ast::Statement::Expression(expr, _),
1926 _,
1927 )) = program.items.first()
1928 {
1929 let ty = infer_expr_type_via_engine(expr);
1930 assert!(
1931 ty.is_some(),
1932 "Engine should infer type for match expression"
1933 );
1934 let ty_str = ty.unwrap();
1935 assert!(
1936 ty_str.contains("bool"),
1937 "Match with all bool arms should be bool, got: {}",
1938 ty_str
1939 );
1940 }
1941 }
1942
1943 #[test]
1944 fn test_infer_expr_type_via_engine_match_union() {
1945 let code = "match 1 { 1 => true, 2 => \"hello\" }";
1946 let program = parse_program(code).unwrap();
1947 if let Some(shape_ast::ast::Item::Statement(
1948 shape_ast::ast::Statement::Expression(expr, _),
1949 _,
1950 )) = program.items.first()
1951 {
1952 let ty = infer_expr_type_via_engine(expr);
1953 assert!(
1954 ty.is_some(),
1955 "Engine should infer type for match with mixed arms"
1956 );
1957 let ty_str = ty.unwrap();
1958 assert!(
1959 ty_str.contains("bool") && ty_str.contains("string"),
1960 "Should be union of bool and string, got: {}",
1961 ty_str
1962 );
1963 }
1964 }
1965
1966 #[test]
1967 fn test_infer_expr_type_match_typed_pattern_numeric_branch_stays_int() {
1968 let code = "let result = match value {\n c: int => c + 1\n _ => 1\n}\n";
1969 let program = parse_program(code).unwrap();
1970 let expr = match program.items.first() {
1971 Some(shape_ast::ast::Item::VariableDecl(decl, _)) => {
1972 decl.value.as_ref().expect("result should have value")
1973 }
1974 Some(shape_ast::ast::Item::Statement(
1975 shape_ast::ast::Statement::VariableDecl(decl, _),
1976 _,
1977 )) => decl.value.as_ref().expect("result should have value"),
1978 other => panic!("expected variable declaration, got {:?}", other),
1979 };
1980
1981 assert_eq!(infer_expr_type(expr).as_deref(), Some("int"));
1982 }
1983
1984 #[test]
1985 fn test_infer_program_types_match_variable() {
1986 let code = "let test = match 2 {\n 0 => true,\n _ => false,\n}";
1987 let program = parse_program(code).unwrap();
1988 let types = infer_program_types(&program);
1989 eprintln!("infer_program_types result: {:?}", types);
1990 assert_eq!(
1991 types.get("test").map(|s| s.as_str()),
1992 Some("bool"),
1993 "test should be inferred as bool from match expression, got: {:?}",
1994 types.get("test")
1995 );
1996 }
1997
1998 #[test]
1999 fn test_type_to_string_concrete() {
2000 let ty = Type::Concrete(TypeAnnotation::Basic("int".to_string()));
2001 assert_eq!(type_to_string(&ty), "int");
2002 }
2003
2004 #[test]
2005 fn test_type_to_string_union() {
2006 let ty = Type::Concrete(TypeAnnotation::Union(vec![
2007 TypeAnnotation::Basic("bool".to_string()),
2008 TypeAnnotation::Basic("string".to_string()),
2009 ]));
2010 assert_eq!(type_to_string(&ty), "bool | string");
2011 }
2012
2013 #[test]
2014 fn test_infer_method_call_type_preserving() {
2015 use shape_ast::ast::{Expr, Span};
2017 let receiver = Box::new(Expr::Array(
2018 vec![
2019 Expr::Literal(Literal::Int(1), Span::default()),
2020 Expr::Literal(Literal::Int(2), Span::default()),
2021 ],
2022 Span::default(),
2023 ));
2024 let expr = Expr::MethodCall {
2025 receiver,
2026 method: "filter".to_string(),
2027 args: vec![],
2028 named_args: vec![],
2029 optional: false,
2030 span: Span::default(),
2031 };
2032 let ty = infer_expr_type(&expr);
2033 assert_eq!(ty, Some("int[]".to_string()), "filter should preserve type");
2034 }
2035
2036 #[test]
2037 fn test_infer_method_call_aggregation() {
2038 use shape_ast::ast::{Expr, Span};
2039 let receiver = Box::new(Expr::Array(vec![], Span::default()));
2040 let expr = Expr::MethodCall {
2041 receiver,
2042 method: "sum".to_string(),
2043 args: vec![],
2044 named_args: vec![],
2045 optional: false,
2046 span: Span::default(),
2047 };
2048 assert_eq!(
2049 infer_expr_type(&expr),
2050 Some("number".to_string()),
2051 "sum() should return number"
2052 );
2053 }
2054
2055 #[test]
2056 fn test_infer_method_call_chained() {
2057 use shape_ast::ast::{Expr, Span};
2058 let array = Box::new(Expr::Array(
2059 vec![Expr::Literal(Literal::Int(1), Span::default())],
2060 Span::default(),
2061 ));
2062 let filtered = Box::new(Expr::MethodCall {
2063 receiver: array,
2064 method: "filter".to_string(),
2065 args: vec![],
2066 named_args: vec![],
2067 optional: false,
2068 span: Span::default(),
2069 });
2070 let reversed = Expr::MethodCall {
2071 receiver: filtered,
2072 method: "reverse".to_string(),
2073 args: vec![],
2074 named_args: vec![],
2075 optional: false,
2076 span: Span::default(),
2077 };
2078 let ty = infer_expr_type(&reversed);
2079 assert_eq!(
2080 ty,
2081 Some("int[]".to_string()),
2082 "chained filter.reverse should preserve type"
2083 );
2084 }
2085
2086 #[test]
2087 fn test_infer_method_call_unwrap() {
2088 use shape_ast::ast::{Expr, Span};
2089 let receiver = Box::new(Expr::TypeAssertion {
2090 expr: Box::new(Expr::Identifier("x".to_string(), Span::default())),
2091 type_annotation: TypeAnnotation::Generic {
2092 name: "Result".into(),
2093 args: vec![TypeAnnotation::Basic("Foo".to_string())],
2094 },
2095 meta_param_overrides: None,
2096 span: Span::default(),
2097 });
2098 let expr = Expr::MethodCall {
2099 receiver,
2100 method: "unwrap".to_string(),
2101 args: vec![],
2102 named_args: vec![],
2103 optional: false,
2104 span: Span::default(),
2105 };
2106 assert_eq!(
2107 infer_expr_type(&expr),
2108 Some("Foo".to_string()),
2109 "unwrap on Result<Foo> should return Foo"
2110 );
2111 }
2112
2113 #[test]
2114 fn test_extract_type_methods_extend_block() {
2115 let code = "extend Foo {\n method bar() {\n self\n }\n}\n";
2116 let program = parse_program(code).unwrap();
2117 let methods = extract_type_methods(&program);
2118 let foo_methods = methods.get("Foo").expect("Should find Foo methods");
2119 assert!(
2120 foo_methods.iter().any(|m| m.name == "bar"),
2121 "Should include 'bar' method from extend block"
2122 );
2123 }
2124
2125 #[test]
2126 fn test_extract_type_methods_from_annotation_comptime_extend_target() {
2127 let code = r#"
2128annotation add_sum() {
2129 targets: [type]
2130 comptime post(target, ctx) {
2131 extend target {
2132 method sum() { self.x + self.y }
2133 }
2134 }
2135}
2136@add_sum()
2137type Point { x: int, y: int }
2138"#;
2139 let program = parse_program(code).unwrap();
2140 let methods = extract_type_methods(&program);
2141 let point_methods = methods.get("Point").expect("Should find Point methods");
2142 assert!(
2143 point_methods.iter().any(|m| m.name == "sum"),
2144 "Should include generated 'sum' method from annotation comptime handler"
2145 );
2146 }
2147
2148 #[test]
2149 fn test_extract_type_methods_from_annotation_comptime_extend_explicit_type() {
2150 let code = r#"
2151annotation add_number_method() {
2152 targets: [function]
2153 comptime post(target, ctx) {
2154 extend Number {
2155 method doubled() { self * 2.0 }
2156 }
2157 }
2158}
2159@add_number_method()
2160fn marker() { 0 }
2161"#;
2162 let program = parse_program(code).unwrap();
2163 let methods = extract_type_methods(&program);
2164 let number_methods = methods.get("Number").expect("Should find Number methods");
2165 assert!(
2166 number_methods.iter().any(|m| m.name == "doubled"),
2167 "Should include generated 'doubled' method on Number"
2168 );
2169 }
2170
2171 #[test]
2172 fn test_extract_type_methods_annotation_not_applied_does_not_generate() {
2173 let code = r#"
2174annotation add_number_method() {
2175 targets: [function]
2176 comptime post(target, ctx) {
2177 extend Number {
2178 method doubled() { self * 2.0 }
2179 }
2180 }
2181}
2182type Point { x: int, y: int }
2183"#;
2184 let program = parse_program(code).unwrap();
2185 let methods = extract_type_methods(&program);
2186 assert!(
2187 !methods.contains_key("Number"),
2188 "Annotation definition without usage must not generate methods"
2189 );
2190 }
2191
2192 #[test]
2193 fn test_extract_type_methods_impl_block() {
2194 let code = r#"
2195trait Queryable {
2196 filter(pred): any;
2197 select(cols): any;
2198 orderBy(col): any
2199}
2200impl Queryable for MyQ {
2201 method filter(pred) { self }
2202}
2203"#;
2204 let program = parse_program(code).unwrap();
2205 let methods = extract_type_methods(&program);
2206 let myq_methods = methods.get("MyQ").expect("Should find MyQ methods");
2207 let names: Vec<&str> = myq_methods.iter().map(|m| m.name.as_str()).collect();
2208 assert!(names.contains(&"filter"), "Should include filter");
2210 assert!(names.contains(&"select"), "Should include select");
2211 assert!(names.contains(&"orderBy"), "Should include orderBy");
2212 }
2213
2214 #[test]
2215 fn test_extract_type_methods_trait_only() {
2216 let code = "trait Foo {\n bar(): any\n}\n";
2218 let program = parse_program(code).unwrap();
2219 let methods = extract_type_methods(&program);
2220 assert!(
2221 methods.is_empty(),
2222 "Trait alone should not produce type methods"
2223 );
2224 }
2225
2226 #[test]
2227 fn test_extract_type_methods_multiple_impls() {
2228 let code = r#"
2229trait A { a1(): any }
2230trait B { b1(): any }
2231impl A for X { method a1() { self } }
2232impl B for X { method b1() { self } }
2233"#;
2234 let program = parse_program(code).unwrap();
2235 let methods = extract_type_methods(&program);
2236 let x_methods = methods.get("X").expect("Should find X methods");
2237 let names: Vec<&str> = x_methods.iter().map(|m| m.name.as_str()).collect();
2238 assert!(names.contains(&"a1"), "Should include a1 from trait A");
2239 assert!(names.contains(&"b1"), "Should include b1 from trait B");
2240 }
2241
2242 #[test]
2243 fn test_infer_function_signatures_return_type() {
2244 let code = "fn add(a: int, b: int) {\n return a + b\n}";
2245 let program = parse_program(code).unwrap();
2246 let sigs = infer_function_signatures(&program);
2247 if let Some(info) = sigs.get("add") {
2248 assert!(
2250 info.param_types.is_empty(),
2251 "Annotated params should not appear: {:?}",
2252 info.param_types
2253 );
2254 assert!(
2256 info.return_type.is_some(),
2257 "Return type should be inferred from body"
2258 );
2259 }
2260 }
2263
2264 #[test]
2265 fn test_infer_function_signatures_unannotated_param_union_from_callsites() {
2266 let code = "fn foo(a) {\n return a\n}\nlet i = foo(1)\nlet s = foo(\"hi\")\n";
2267 let program = parse_program(code).unwrap();
2268 let sigs = infer_function_signatures(&program);
2269 let info = sigs.get("foo").expect("foo should have inferred signature");
2270 let param = info
2271 .param_types
2272 .iter()
2273 .find(|(name, _)| name == "a")
2274 .expect("expected inferred type for param a");
2275 assert!(
2276 param.1.contains("int") && param.1.contains("string"),
2277 "expected union param type, got {}",
2278 param.1
2279 );
2280 let ret = info.return_type.as_deref().unwrap_or("");
2281 assert!(
2282 ret.contains("int") && ret.contains("string"),
2283 "expected union return type, got {}",
2284 ret
2285 );
2286 assert!(
2287 matches!(
2288 info.param_ref_modes.get("a"),
2289 Some(ParamReferenceMode::Shared)
2290 ),
2291 "expected read-only inferred reference mode for union param"
2292 );
2293 }
2294
2295 #[test]
2296 fn test_infer_function_signatures_marks_mutating_ref_params() {
2297 let code = r#"
2298fn mutate(a) {
2299 a = "new"
2300 return a
2301}
2302let s = "old"
2303mutate(s)
2304"#;
2305 let program = parse_program(code).unwrap();
2306 let sigs = infer_function_signatures(&program);
2307 let info = sigs
2308 .get("mutate")
2309 .expect("mutate should have inferred signature");
2310 assert!(
2311 matches!(
2312 info.param_ref_modes.get("a"),
2313 Some(ParamReferenceMode::Exclusive)
2314 ),
2315 "expected mutating inferred reference mode"
2316 );
2317 }
2318
2319 #[test]
2320 fn test_infer_function_signatures_skips_annotated() {
2321 let code = "fn greet(name: string) -> string {\n return name\n}";
2322 let program = parse_program(code).unwrap();
2323 let sigs = infer_function_signatures(&program);
2324 assert!(
2326 sigs.get("greet").is_none(),
2327 "Fully annotated function should have no inferred signatures"
2328 );
2329 }
2330}