1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5use smol_str::SmolStr;
6
/// Result of inferring one return site (an explicit `return` or a tail expression).
///
/// `ty` is the value type used when merging return sites. `shape` records a coarse
/// struct/map/list kind (as computed by `return_shape`), so that incompatible return
/// sites can still be rejected when `ty` has been widened to `Any`.
#[derive(Clone)]
struct ReturnInfo {
    /// Inferred value type; `infer_return_expr` widens it to `Any` for map/list shapes.
    ty: Type,
    /// Coarse shape of the returned value, when one could be determined.
    shape: Option<Type>,
}
12
13impl Compiler {
14 fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
15 self.infer_stack.last().cloned()
16 }
17
18 fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
19 self.fns.get(&id).and_then(|fns| {
20 fns.iter().find_map(|item| {
21 if item.0 == generic_args
22 && item.1 == fn_tys
23 && let FnInferRet::Pending(seed) = &item.2
24 {
25 seed.clone()
26 } else {
27 None
28 }
29 })
30 })
31 }
32
33 fn update_pending_return_seed(&mut self, ty: &Type) {
34 if ty.is_any() {
35 return;
36 }
37 let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
38 return;
39 };
40 let Some(fns) = self.fns.get_mut(&id) else {
41 return;
42 };
43 if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
44 && let FnInferRet::Pending(seed) = &mut item.2
45 {
46 let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
47 *seed = Some(next);
48 }
49 }
50
51 fn try_find_base_return_ty(&self, body: &Stmt) -> Option<Type> {
53 match &body.kind {
54 StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty(s)),
55 StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty(then_body)
56 .or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty(b))),
57 StmtKind::Return(Some(expr)) => Self::try_literal_type(expr),
58 StmtKind::Expr(expr, false) => Self::try_literal_type(expr),
59 _ => None,
60 }
61 }
62
63 fn try_find_base_return_ty_with_scope(&mut self, body: &Stmt, fn_id: u32, fn_name: &str, args: &[SmolStr], fn_tys: &[Type]) -> Option<Type> {
65 let saved_state = self.take_local_state();
66 self.frames.push(0);
67 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
68 self.add_name(arg.clone());
69 self.add_ty(ty.clone());
70 }
71 let result = self.try_find_base_return_ty_with_scope_inner(body, fn_id, fn_name);
72 self.restore_local_state(saved_state);
73 result
74 }
75
76 fn try_find_base_return_ty_with_scope_inner(&mut self, body: &Stmt, fn_id: u32, fn_name: &str) -> Option<Type> {
77 match &body.kind {
78 StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty_with_scope_inner(s, fn_id, fn_name)),
79 StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty_with_scope_inner(then_body, fn_id, fn_name)
80 .or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty_with_scope_inner(b, fn_id, fn_name))),
81 StmtKind::Return(Some(expr)) => {
82 if Self::expr_calls_fn(expr, fn_id, fn_name) { None }
83 else { self.infer_return_expr(expr).ok().map(|info| info.ty) }
84 }
85 StmtKind::Expr(expr, false) => {
86 if Self::expr_calls_fn(expr, fn_id, fn_name) { None }
87 else { self.infer_return_expr(expr).ok().map(|info| info.ty) }
88 }
89 _ => None,
90 }
91 }
92
93 fn expr_calls_fn(expr: &Expr, fn_id: u32, fn_name: &str) -> bool {
94 match &expr.kind {
95 ExprKind::Call { obj, params } => {
96 if let ExprKind::Id(id, _) = &obj.kind { return *id == fn_id; }
97 if let ExprKind::Ident(name) = &obj.kind {
98 if name.as_str() == fn_name || fn_name.ends_with(&format!("::{}", name)) { return true; }
99 }
100 params.iter().any(|p| Self::expr_calls_fn(p, fn_id, fn_name))
101 }
102 ExprKind::Binary { left, op: _, right } => Self::expr_calls_fn(left, fn_id, fn_name) || Self::expr_calls_fn(right, fn_id, fn_name),
103 ExprKind::Unary { op: _, value } => Self::expr_calls_fn(value, fn_id, fn_name),
104 ExprKind::Typed { value, ty: _ } => Self::expr_calls_fn(value, fn_id, fn_name),
105 _ => false,
106 }
107 }
108
109 fn try_literal_type(expr: &Expr) -> Option<Type> {
110 match &expr.kind {
111 ExprKind::Value(v) => Some(v.get_type()),
112 ExprKind::Unary { op: UnaryOp::Neg, value } => Self::try_literal_type(value),
113 _ => None,
114 }
115 }
116
117 fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
118 match &pat.kind {
119 PatternKind::Ident { name, ty } => {
120 let annotated_ty = self.symbols.get_type(ty)?;
121 self.add_name(name.clone());
122 self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
123 }
124 PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
125 PatternKind::Tuple(pats) => {
126 if let Type::Tuple(tys) = expr_ty {
127 for (pat, ty) in pats.iter().zip(tys) {
128 self.add_pattern_bindings_for_infer(pat, ty)?;
129 }
130 } else {
131 for pat in pats {
132 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
133 }
134 }
135 }
136 PatternKind::List { elems, .. } => {
137 for pat in elems {
138 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
139 }
140 }
141 PatternKind::Wildcard => {
142 self.add_name("".into());
143 self.add_ty(expr_ty);
144 }
145 PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
146 }
147 Ok(())
148 }
149
150 fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
151 if matches!(range.kind, ExprKind::Range { .. }) {
152 return self.infer_range_expr(range);
153 }
154 Ok(match self.infer_expr(range)? {
155 Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
156 _ => Type::Any,
157 })
158 }
159
160 fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
161 let ExprKind::Range { start, stop, .. } = &range.kind else {
162 return self.infer_expr(range);
163 };
164 let start_ty = self.infer_expr(start)?;
165 let stop_ty = self.infer_expr(stop)?;
166 Ok(Self::merge_range_bound_types(start_ty, stop_ty))
167 }
168
169 fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
170 if start_ty.is_any() {
171 stop_ty
172 } else if stop_ty.is_any() {
173 start_ty
174 } else if start_ty == Type::I32 && stop_ty.is_uint() {
175 stop_ty
176 } else if stop_ty == Type::I32 && start_ty.is_uint() {
177 start_ty
178 } else {
179 start_ty + stop_ty
180 }
181 }
182
183 fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
184 match left {
185 Some(left) if left == right => Ok(left),
186 Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
187 Some(left) if left.is_any() || right.is_any() => Ok(Type::Any),
188 Some(left) => Ok(left + right),
189 None => Ok(right),
190 }
191 }
192
193 fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
194 if !ty.is_any() {
195 return match ty {
196 Type::Struct { .. } => Some(ty.clone()),
197 Type::Map => Some(Type::Map),
198 Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
199 _ => None,
200 };
201 }
202 match &expr.kind {
203 ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
204 ExprKind::Dict(_) => Some(Type::Map),
205 ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
206 ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
207 ExprKind::Typed { ty, .. } => Some(ty.clone()),
208 _ => None,
209 }
210 }
211
212 fn dynamic_return_shape(ty: Type) -> Option<Type> {
213 match ty {
214 Type::Map => Some(Type::Map),
215 Type::List(elem) => Some(Type::List(elem)),
216 Type::Array(elem, _) => Some(Type::List(elem)),
217 _ => None,
218 }
219 }
220
221 fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
222 match &expr.kind {
223 ExprKind::Var(idx) => Some(*idx),
224 ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
225 _ => None,
226 }
227 }
228
229 fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
230 match method {
231 "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
232 Some(ListElemState::Known(ty)) => ty,
233 Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
234 None => elem_ty.clone(),
235 })),
236 "push" => {
237 let pushed_ty = params
238 .first()
239 .map(|param| {
240 if let Some(value) = self.get_value(param)
241 && (value.is_str() || value.is_native())
242 {
243 Ok(value.get_type())
244 } else {
245 self.infer_expr(param)
246 }
247 })
248 .transpose()?
249 .unwrap_or(Type::Any);
250 if let Some(idx) = self.local_var_idx_for_expr(target) {
251 let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
252 let next_state = match state {
253 ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
254 ListElemState::Unknown => ListElemState::Known(pushed_ty),
255 ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
256 ListElemState::Known(prev) => {
257 let merged = if prev == pushed_ty {
258 prev
259 } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
260 prev + pushed_ty
261 } else {
262 Type::Any
263 };
264 if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
265 }
266 ListElemState::Mixed => ListElemState::Mixed,
267 };
268 let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
269 self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
270 self.set_list_elem_state(idx, Some(next_state));
271 }
272 Ok(Some(Type::Void))
273 }
274 "len" => Ok(Some(Type::I32)),
275 "is_list" | "is_null" => Ok(Some(Type::Bool)),
276 _ => Ok(None),
277 }
278 }
279
280 fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
281 let ty = self.infer_expr(expr)?;
282 let shape = self.return_shape(expr, &ty);
283 let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
284 Ok(ReturnInfo { ty, shape })
285 }
286
287 fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
288 let Some(left) = left else {
289 return Ok(right);
290 };
291 if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
292 && left_shape != right_shape
293 {
294 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
295 }
296 if let Some(left_shape) = &left.shape
297 && left_shape.is_struct()
298 && right.ty.is_any()
299 && right.shape.is_none()
300 {
301 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
302 }
303 if let Some(right_shape) = &right.shape
304 && right_shape.is_struct()
305 && left.ty.is_any()
306 && left.shape.is_none()
307 {
308 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
309 }
310 let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
311 Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
312 }
313
314 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
315 self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
316 }
317
318 pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
319 self.infer_returns(stmt, true).map(|_| ())
320 }
321
    /// Walks `stmt`, collecting every return site and merging them into one `ReturnInfo`.
    ///
    /// Returns `(merged_return, always_returns)`:
    /// - `always_returns` is `true` when every path through `stmt` was found to end in a return.
    /// - `tail` says whether `stmt` is in tail position of the function body. Only then does an
    ///   unterminated expression statement count as an implicit return.
    ///
    /// Each discovered return type is also fed into the pending return seed of the function on
    /// top of `infer_stack`, via `update_pending_return_seed`.
    ///
    /// # Errors
    /// - An `if`/`while` condition is not `Bool`.
    /// - `merge_return_info` rejects two return sites.
    /// - A sub-expression or statement fails to infer.
    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
        match &stmt.kind {
            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
            StmtKind::Block(stmts) => {
                let mut ret = None;
                for (idx, stmt) in stmts.iter().enumerate() {
                    // Only the last statement of a tail block is itself in tail position.
                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
                    if let Some(info) = info {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    // Statements after one that always returns are not inspected.
                    if always_returns {
                        return Ok((ret, true));
                    }
                }
                Ok((ret, false))
            }
            StmtKind::If { cond, then_body, else_body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
                if let Some(ret) = &ret {
                    self.update_pending_return_seed(&ret.ty);
                }
                // Without an `else`, the `if` as a whole never guarantees a return.
                let else_returns = if let Some(body) = else_body {
                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
                    if let Some(info) = else_ty {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    else_returns
                } else {
                    false
                };
                Ok((ret, then_returns && else_returns))
            }
            StmtKind::While { cond, body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                // A `while` is never treated as always returning.
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            // Unlike `while`/`for`, the body's `always_returns` flag is propagated unchanged.
            StmtKind::Loop(body) => self.infer_returns(body, false),
            StmtKind::For { pat, range, body } => {
                let ty = self.for_pattern_ty(range)?;
                self.add_pattern_bindings_for_infer(pat, ty)?;
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            StmtKind::Let { .. } => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
            StmtKind::Expr(expr, close) => {
                // Always inferred (for its side effects on local types). It counts as an
                // implicit return only when unterminated and in tail position.
                let info = self.infer_return_expr(expr)?;
                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
            }
            _ => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
        }
    }
394
    /// Infers the static type of `expr` in the current local scope.
    ///
    /// Besides computing a type, this has side effects on the inference state:
    /// - Assignments (`op.is_assign()`) whose target is a local `Var` re-type that local.
    /// - Built-in list methods on list-typed receivers may update the local's tracked element
    ///   state (via `infer_list_method`).
    /// - Calls run (cached) inference of the callee through `infer_fn` / `infer_fn_with_params`.
    ///
    /// Many unknown or dynamic cases yield `Type::Any` instead of an error.
    ///
    /// # Errors
    /// - An identifier cannot be found, or a call target is not a function.
    /// - A multi-assignment has mismatched arity.
    /// - A struct field index is out of range, or a repeat length is invalid.
    /// - Symbol/type resolution or callee inference fails.
    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
        match &expr.kind {
            // `null` and map literals are dynamically typed; list literals keep their own type.
            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
            ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
            ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
            ExprKind::Value(v) => Ok(v.get_type()),
            // Constant-pool entries: only strings and empty lists get a concrete type here.
            ExprKind::Const(idx) => Ok(match self.consts.get(*idx) {
                Some(value) if value.is_str() => Type::Str,
                Some(value) if value.is_list() && value.len() == 0 => Type::list_any(),
                _ => Type::Any,
            }),
            // `idx` is relative to the current frame base (`self.top()`).
            ExprKind::Var(idx) => {
                let idx = self.top() + (*idx as usize);
                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
            }
            // Unresolved name: innermost local of the current frame first, then global symbols.
            ExprKind::Ident(ident) => {
                for idx in (self.top()..self.names.len()).rev() {
                    if self.names[idx].eq(ident) && idx < self.tys.len() {
                        return self.symbols.get_type(&self.tys[idx]);
                    }
                }
                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
                match self.symbols.get_symbol(id)?.1 {
                    Symbol::Const { ty, .. } => Ok(ty.clone()),
                    Symbol::Static { ty, .. } => Ok(ty.clone()),
                    Symbol::Struct(ty, _) => Ok(ty.clone()),
                    // A function reference is represented as a symbol type, not by its return type.
                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
                    Symbol::Native(ty) => Ok(ty.clone()),
                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
                }
            }
            // Already-resolved global symbol; same classification as the `Ident` fallback above.
            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
                Symbol::Const { ty, .. } => Ok(ty.clone()),
                Symbol::Static { ty, .. } => Ok(ty.clone()),
                Symbol::Struct(ty, _) => Ok(ty.clone()),
                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
                Symbol::Native(ty) => Ok(ty.clone()),
                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
            },
            // `f::<T>`: attach the (resolved where possible) generic arguments to the symbol type.
            ExprKind::Generic { obj, params } => {
                let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
                match self.infer_expr(obj)? {
                    Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
                    _ => Ok(Type::Any),
                }
            }
            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
            ExprKind::Unary { op, value } => match op {
                // `!` on an integer keeps the integer type; otherwise it yields `Bool`.
                UnaryOp::Not => {
                    let ty = self.infer_expr(value.as_ref())?;
                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
                }
                UnaryOp::Neg => self.infer_expr(value.as_ref()),
                UnaryOp::Unknow => Ok(Type::Any),
            },
            ExprKind::Binary { left, op, right } => {
                // Multi-assignment `(a, b) = ...`: arity is checked against a tuple/list right side,
                // and only the right side is inferred.
                if op == &BinaryOp::Assign
                    && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
                {
                    if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
                        if left_items.len() != right_items.len() {
                            return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
                        }
                        for item in right_items {
                            let _ = self.infer_expr(item)?;
                        }
                    } else {
                        let _ = self.infer_expr(right)?;
                    }
                    return Ok(Type::Void);
                }
                // Local slot to re-type after an assignment, when the target is a local `Var`.
                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
                let ty = if op.is_logic() {
                    // Logical operators yield `Bool`; their operands are not inferred.
                    Type::Bool
                } else if op == &BinaryOp::Idx {
                    // Indexing or field access on `left` with key `right`.
                    let left_ty = self.infer_expr(left)?;
                    if let Type::Array(elem_ty, _) = left_ty {
                        (*elem_ty).clone()
                    } else if let Type::Vec(elem_ty, _) = left_ty {
                        (*elem_ty).clone()
                    } else if let Type::List(elem_ty) = left_ty {
                        (*elem_ty).clone()
                    } else {
                        let left_ty = self.symbols.get_type(&left_ty)?;
                        let right_ty = if right.is_value() || right.is_const() {
                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
                            if right_value.is_str() {
                                // Named field or method: a method yields its return type.
                                if left_ty.is_any() {
                                    return Ok(Type::Any);
                                }
                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
                                }
                            } else if let Type::Struct { fields, .. } = &left_ty
                                && let Some(idx) = right_value.as_int()
                            {
                                // Positional struct field access.
                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
                            }
                            right_value.get_type()
                        } else {
                            self.infer_expr(right)?
                        };
                        if right_ty.is_int() || right_ty.is_uint() {
                            // Integer index on a type other than Array/Vec/List: use the return
                            // type of its `get_idx` field.
                            if left_ty.is_any() {
                                return Ok(Type::Any);
                            }
                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
                            let fn_ty = self.symbols.get_type(&s)?;
                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
                        }
                        if left_ty.is_any() {
                            return Ok(Type::Any);
                        }
                        Type::Any
                    }
                } else {
                    let left_ty = self.infer_expr(left)?;
                    let right_ty = self.infer_expr(right)?;
                    if op == &BinaryOp::Assign {
                        // Assigning a dynamic value keeps the target's known type.
                        if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
                    } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
                        left_ty
                    } else {
                        left_ty + right_ty
                    }
                };
                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
                Ok(ty)
            }
            ExprKind::Call { obj, params } => {
                if let ExprKind::Assoc { ty, name } = &obj.kind {
                    // `Type::name(...)`: generic arguments come from the type path.
                    let base_name = match ty {
                        Type::Ident { name, .. } => name.clone(),
                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
                        _ => return Ok(Type::Any),
                    };
                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
                    let generic_args = match ty {
                        Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
                        _ => Vec::new(),
                    };
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn_with_params(id, &args, &generic_args)
                } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn_with_params(*id, &args, generic_args)
                } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
                    // `f::<T>(...)`: the callee must infer to a function symbol type.
                    let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
                        return Ok(Type::Any);
                    };
                    let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn_with_params(id, &args, &generic_args)
                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
                    // Method resolved through an explicit type; the receiver is the first argument.
                    let base_name = match ty {
                        Type::Ident { name, .. } => name.clone(),
                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
                        _ => return Ok(Type::Any),
                    };
                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
                    let mut args = vec![self.infer_expr(target)?];
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(id, &args)
                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
                    // Resolved function/method. For a list/array receiver, built-in list methods
                    // (named by the last `::` segment) are tried first.
                    let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
                    if let Some(target) = obj_expr
                        && let Some(method) = method
                    {
                        let target_ty = self.infer_expr(target)?;
                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
                            && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
                        {
                            return Ok(ret_ty);
                        }
                    }
                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(*id, &args)
                } else if let ExprKind::Ident(name) = &obj.kind {
                    // A local of the current frame shadows globals; it is callable only when it
                    // holds a function symbol type.
                    for idx in (self.top()..self.names.len()).rev() {
                        if self.names[idx].eq(name) && idx < self.tys.len() {
                            return if let Type::Symbol { id, .. } = &self.tys[idx] {
                                let id = *id;
                                let mut args = Vec::new();
                                for p in params {
                                    args.push(self.infer_expr(p)?);
                                }
                                self.infer_fn(id, &args)
                            } else {
                                Ok(Type::Any)
                            };
                        }
                    }
                    let Ok(id) = self.symbols.get_id(name) else {
                        return Ok(Type::Any);
                    };
                    if !self.symbols.get_symbol(id)?.1.is_fn() {
                        return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
                    }
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(id, &args)
                } else if obj.is_idx() {
                    // Callee is an index expression `target[method]` whose key names a method.
                    // The receiver type is passed as the first argument.
                    let (target, _, method) = obj.clone().binary().unwrap();
                    let ty = self.infer_expr(&target)?;
                    if let Some(method) = self.get_value(&method) {
                        let method = method.as_str();
                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
                            && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
                        {
                            return Ok(ret_ty);
                        }
                        // Prefer a field of the receiver type; otherwise a global function of that name.
                        let fn_ty = match self.get_field(&ty, method) {
                            Ok((_, fn_ty)) => fn_ty,
                            Err(_) => {
                                let id = self.symbols.get_id(method)?;
                                if self.symbols.get_symbol(id)?.1.is_fn() {
                                    Type::Symbol { id, params: Vec::new() }
                                } else {
                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
                                }
                            }
                        };
                        if let Type::Symbol { id, .. } = fn_ty {
                            let mut args = vec![ty];
                            for p in params {
                                args.push(self.infer_expr(p)?);
                            }
                            self.infer_fn(id, &args)
                        } else {
                            Ok(fn_ty)
                        }
                    } else {
                        Ok(Type::Any)
                    }
                } else if let ExprKind::Var(idx) = &obj.kind {
                    // Local variable holding a function symbol type.
                    let idx = self.top() + (*idx as usize);
                    if idx < self.tys.len()
                        && let Type::Symbol { id, .. } = self.tys[idx]
                    {
                        let mut args = Vec::new();
                        for p in params {
                            args.push(self.infer_expr(p)?);
                        }
                        self.infer_fn(id, &args)
                    } else {
                        Ok(Type::Any)
                    }
                } else if obj.is_value() {
                    Ok(Type::Void)
                } else {
                    Ok(Type::Any)
                }
            }
            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
            // `[value; len]`: a constant length gives a fixed-size array type.
            ExprKind::Repeat { value, len } => {
                let value_ty = self.infer_expr(value)?;
                let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
                if let Type::ConstInt(len) = len {
                    let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
                    Ok(Type::Array(std::rc::Rc::new(value_ty), len))
                } else {
                    Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
                }
            }
            // Non-empty list literal: a fixed-size array whose element type combines the item types.
            ExprKind::List(items) => {
                if items.is_empty() {
                    return Ok(Type::list_any());
                }
                let mut elem_ty = Type::Any;
                for item in items {
                    let item_ty = self.infer_expr(item)?;
                    elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
                }
                Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
            }
            ExprKind::Range { start, stop, .. } => {
                let start_ty = self.infer_expr(start)?;
                let stop_ty = self.infer_expr(stop)?;
                Ok(Self::merge_range_bound_types(start_ty, stop_ty))
            }
            _ => Ok(Type::Any),
        }
    }
695
696 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
697 let mut fn_tys = Vec::new();
698 for (i, ty) in tys.iter().enumerate() {
699 if !ty.is_any() {
700 fn_tys.push(ty.clone());
701 } else if let Some(arg_ty) = arg_tys.get(i) {
702 fn_tys.push(self.symbols.get_type(arg_ty)?);
703 } else {
704 fn_tys.push(Type::Any);
705 }
706 }
707 Ok(fn_tys)
708 }
709
710 fn is_optimizable_local_ty(ty: &Type) -> bool {
711 ty.is_bool() || ty.is_native()
712 }
713
714 fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
715 matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
716 }
717
718 fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
719 let ty = self.tys.get(pos)?;
720 match ty {
721 Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
722 if let ListElemState::Known(elem_ty) = state
723 && Self::is_optimizable_list_elem_ty(&elem_ty)
724 {
725 Some(Type::List(std::rc::Rc::new(elem_ty)))
726 } else {
727 None
728 }
729 }),
730 ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
731 _ => None,
732 }
733 }
734
735 fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
736 (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
737 }
738
739 fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
740 let items = self.local_type_hints.entry(id).or_default();
741 if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
742 item.2 = hints;
743 } else {
744 items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
745 }
746 }
747
748 pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
749 self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
750 }
751
752 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
753 self.infer_fn_with_params(id, arg_tys, &[])
754 }
755
756 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
757 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
758 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
759 if let Type::Fn { tys, ret: _ } = ty {
760 let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
761 let generic_args = resolved_generic_args.as_slice();
762 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
763 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
764 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
765 let body = if generic_params.is_empty() {
766 body
767 } else {
768 let mut compile_tys = tys.clone();
769 let mut compile_cap = cap.clone();
770 let saved_state = self.take_local_state();
771 if let Some((module, _)) = name.split_once("::") {
772 self.symbols.push_module_scope(module.into());
773 }
774 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
775 if name.contains("::") {
776 self.symbols.pop_module_scope();
777 }
778 self.restore_local_state(saved_state);
779 Stmt::new(StmtKind::Block(compiled?), Span::default())
780 };
781 if let Some(fns) = self.fns.get_mut(&id) {
782 for f in fns.iter() {
783 if f.0 == generic_args && f.1 == fn_tys {
784 return match &f.2 {
785 FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
786 FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or_else(|| {
787 if self.infer_stack.iter().any(|(sid, sargs, _)| *sid == id && sargs == generic_args) {
789 if let Some(base_ty) = self.try_find_base_return_ty(&body) {
790 return self.symbols.get_type(&base_ty);
791 }
792 }
793 Ok(Type::Any)
794 }),
795 };
796 }
797 }
798 fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
799 } else {
800 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
801 }
802 if self.pending_return_seed(id, generic_args, &fn_tys).is_none() {
804 if let Some(base_ty) = self.try_find_base_return_ty_with_scope(&body, id, &name, &args, &fn_tys) {
805 if let Some(fns) = self.fns.get_mut(&id) {
806 if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
807 && let FnInferRet::Pending(seed) = &mut item.2
808 && seed.is_none()
809 {
810 *seed = Some(base_ty);
811 }
812 }
813 }
814 }
815 let mut ret_ty = None;
816 let mut local_type_hints = Vec::new();
817 for _ in 0..4 {
818 let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
819 let saved_state = self.take_local_state();
820 self.frames.push(0);
821 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
822 self.add_name(arg.clone());
823 self.add_ty(ty.clone());
824 }
825 for c in cap.vars.iter() {
826 if let Some((name, ty)) = cap.names.get(*c) {
827 self.add_name(name.clone());
828 self.add_ty(ty.clone());
829 } else {
830 self.add_name("".into());
831 self.add_ty(Type::Any);
832 }
833 }
834 self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
835 let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
836 self.infer_stack.pop();
837 let pass_local_type_hints = self.collect_local_type_hints();
838 self.restore_local_state(saved_state);
839 let pass_ret_ty = match pass_ret_ty {
840 Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
841 Err(err) => {
842 log::error!("infer_fn {} failed: {:?}", name, err);
843 let should_remove = self
844 .fns
845 .get_mut(&id)
846 .map(|fns| {
847 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
848 fns.is_empty()
849 })
850 .unwrap_or(false);
851 if should_remove {
852 self.fns.remove(&id);
853 }
854 return Err(err);
855 }
856 };
857 if !pass_ret_ty.is_any() {
858 self.update_pending_return_seed(&pass_ret_ty);
859 ret_ty = Some(pass_ret_ty.clone());
860 } else if ret_ty.is_none() {
861 ret_ty = Some(pass_ret_ty);
862 }
863 local_type_hints = pass_local_type_hints;
864 let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
865 if before_seed == after_seed {
866 break;
867 }
868 }
869 let ret_ty = ret_ty.unwrap_or(Type::Any);
870 self.fns.get_mut(&id).map(|f| {
871 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
872 });
873 self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
874 if generic_args.is_empty()
875 && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
876 && ret.is_any()
877 {
878 *ret = std::rc::Rc::new(ret_ty.clone());
879 }
880 Ok(ret_ty)
881 } else {
882 Ok(Type::Any)
883 }
884 } else if let Symbol::Native(f) = s {
885 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
886 } else if matches!(s, Symbol::Null) {
887 Ok(Type::Any)
888 } else {
889 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
890 }
891 }
892
893 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
894 match &stmt.kind {
895 StmtKind::Expr(expr, close) => {
896 if !close {
897 self.infer_expr(expr)
898 } else {
899 self.infer_expr(expr)?;
900 Ok(Type::Void)
901 }
902 }
903 StmtKind::Return(expr) => {
904 if let Some(e) = expr {
905 self.infer_expr(e)
906 } else {
907 Ok(Type::Void)
908 }
909 }
910 StmtKind::Block(stmts) => {
911 for (idx, stmt) in stmts.iter().enumerate() {
912 let ty = self.infer_stmt(stmt)?;
913 if stmt.is_return() || idx == stmts.len() - 1 {
914 return Ok(ty);
915 }
916 }
917 Ok(Type::Void)
918 }
919 StmtKind::If { then_body, else_body, .. } => {
920 let then_ty = self.infer_stmt(then_body)?;
921 if let Some(e) = else_body {
922 let else_ty = self.infer_stmt(e)?;
923 if then_ty != else_ty {
924 log::debug!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
925 return Self::merge_return_type(stmt.span, Some(then_ty), else_ty);
926 }
927 }
928 if else_body.is_none() {
929 return Ok(Type::Void);
930 }
931 Ok(then_ty)
932 }
933 StmtKind::While { cond, body } => {
934 let cond_ty = self.infer_expr(cond)?;
935 if cond_ty != Type::Bool {
936 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
937 }
938 self.infer_stmt(body)
939 }
940 StmtKind::For { pat, range, body } => {
941 let ty = self.for_pattern_ty(range)?;
942 self.add_pattern_bindings_for_infer(pat, ty)?;
943 self.infer_stmt(body)
944 }
945 StmtKind::Let { pat, value } => {
946 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
947 self.add_pattern_bindings_for_infer(pat, expr_ty)?;
948 Ok(Type::Void)
949 }
950 _ => Ok(Type::Void),
951 }
952 }
953}