Skip to main content

compiler/
infer.rs

1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5use smol_str::SmolStr;
6
7#[derive(Clone)]
8struct ReturnInfo {
9    ty: Type,
10    shape: Option<Type>,
11}
12
13impl Compiler {
14    fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
15        self.infer_stack.last().cloned()
16    }
17
18    fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
19        self.fns.get(&id).and_then(|fns| {
20            fns.iter().find_map(|item| {
21                if item.0 == generic_args
22                    && item.1 == fn_tys
23                    && let FnInferRet::Pending(seed) = &item.2
24                {
25                    seed.clone()
26                } else {
27                    None
28                }
29            })
30        })
31    }
32
33    fn update_pending_return_seed(&mut self, ty: &Type) {
34        if ty.is_any() {
35            return;
36        }
37        let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
38            return;
39        };
40        let Some(fns) = self.fns.get_mut(&id) else {
41            return;
42        };
43        if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
44            && let FnInferRet::Pending(seed) = &mut item.2
45        {
46            let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
47            *seed = Some(next);
48        }
49    }
50
51    /// 扫描函数体,查找第一个非递归路径上的返回值类型(仅处理字面量)。
52    fn try_find_base_return_ty(&self, body: &Stmt) -> Option<Type> {
53        match &body.kind {
54            StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty(s)),
55            StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty(then_body)
56                .or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty(b))),
57            StmtKind::Return(Some(expr)) => Self::try_literal_type(expr),
58            StmtKind::Expr(expr, false) => Self::try_literal_type(expr),
59            _ => None,
60        }
61    }
62
63    /// 带作用域的 base case 返回类型查找
64    fn try_find_base_return_ty_with_scope(&mut self, body: &Stmt, fn_id: u32, fn_name: &str, args: &[SmolStr], fn_tys: &[Type]) -> Option<Type> {
65        let saved_state = self.take_local_state();
66        self.frames.push(0);
67        for (arg, ty) in args.iter().zip(fn_tys.iter()) {
68            self.add_name(arg.clone());
69            self.add_ty(ty.clone());
70        }
71        let result = self.try_find_base_return_ty_with_scope_inner(body, fn_id, fn_name);
72        self.restore_local_state(saved_state);
73        result
74    }
75
76    fn try_find_base_return_ty_with_scope_inner(&mut self, body: &Stmt, fn_id: u32, fn_name: &str) -> Option<Type> {
77        match &body.kind {
78            StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty_with_scope_inner(s, fn_id, fn_name)),
79            StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty_with_scope_inner(then_body, fn_id, fn_name)
80                .or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty_with_scope_inner(b, fn_id, fn_name))),
81            StmtKind::Return(Some(expr)) => {
82                if Self::expr_calls_fn(expr, fn_id, fn_name) { None }
83                else { self.infer_return_expr(expr).ok().map(|info| info.ty) }
84            }
85            StmtKind::Expr(expr, false) => {
86                if Self::expr_calls_fn(expr, fn_id, fn_name) { None }
87                else { self.infer_return_expr(expr).ok().map(|info| info.ty) }
88            }
89            _ => None,
90        }
91    }
92
93    fn expr_calls_fn(expr: &Expr, fn_id: u32, fn_name: &str) -> bool {
94        match &expr.kind {
95            ExprKind::Call { obj, params } => {
96                if let ExprKind::Id(id, _) = &obj.kind { return *id == fn_id; }
97                if let ExprKind::Ident(name) = &obj.kind {
98                    if name.as_str() == fn_name || fn_name.ends_with(&format!("::{}", name)) { return true; }
99                }
100                params.iter().any(|p| Self::expr_calls_fn(p, fn_id, fn_name))
101            }
102            ExprKind::Binary { left, op: _, right } => Self::expr_calls_fn(left, fn_id, fn_name) || Self::expr_calls_fn(right, fn_id, fn_name),
103            ExprKind::Unary { op: _, value } => Self::expr_calls_fn(value, fn_id, fn_name),
104            ExprKind::Typed { value, ty: _ } => Self::expr_calls_fn(value, fn_id, fn_name),
105            _ => false,
106        }
107    }
108
109    fn try_literal_type(expr: &Expr) -> Option<Type> {
110        match &expr.kind {
111            ExprKind::Value(v) => Some(v.get_type()),
112            ExprKind::Unary { op: UnaryOp::Neg, value } => Self::try_literal_type(value),
113            _ => None,
114        }
115    }
116
117    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
118        match &pat.kind {
119            PatternKind::Ident { name, ty } => {
120                let annotated_ty = self.symbols.get_type(ty)?;
121                self.add_name(name.clone());
122                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
123            }
124            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
125            PatternKind::Tuple(pats) => {
126                if let Type::Tuple(tys) = expr_ty {
127                    for (pat, ty) in pats.iter().zip(tys) {
128                        self.add_pattern_bindings_for_infer(pat, ty)?;
129                    }
130                } else {
131                    for pat in pats {
132                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
133                    }
134                }
135            }
136            PatternKind::List { elems, .. } => {
137                for pat in elems {
138                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
139                }
140            }
141            PatternKind::Wildcard => {
142                self.add_name("".into());
143                self.add_ty(expr_ty);
144            }
145            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
146        }
147        Ok(())
148    }
149
150    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
151        if matches!(range.kind, ExprKind::Range { .. }) {
152            return self.infer_range_expr(range);
153        }
154        Ok(match self.infer_expr(range)? {
155            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
156            _ => Type::Any,
157        })
158    }
159
160    fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
161        let ExprKind::Range { start, stop, .. } = &range.kind else {
162            return self.infer_expr(range);
163        };
164        let start_ty = self.infer_expr(start)?;
165        let stop_ty = self.infer_expr(stop)?;
166        Ok(Self::merge_range_bound_types(start_ty, stop_ty))
167    }
168
169    fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
170        if start_ty.is_any() {
171            stop_ty
172        } else if stop_ty.is_any() {
173            start_ty
174        } else if start_ty == Type::I32 && stop_ty.is_uint() {
175            stop_ty
176        } else if stop_ty == Type::I32 && start_ty.is_uint() {
177            start_ty
178        } else {
179            start_ty + stop_ty
180        }
181    }
182
183    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
184        match left {
185            Some(left) if left == right => Ok(left),
186            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
187            Some(left) if left.is_any() || right.is_any() => Ok(Type::Any),
188            Some(left) => Ok(left + right),
189            None => Ok(right),
190        }
191    }
192
193    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
194        if !ty.is_any() {
195            return match ty {
196                Type::Struct { .. } => Some(ty.clone()),
197                Type::Map => Some(Type::Map),
198                Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
199                _ => None,
200            };
201        }
202        match &expr.kind {
203            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
204            ExprKind::Dict(_) => Some(Type::Map),
205            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
206            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
207            ExprKind::Typed { ty, .. } => Some(ty.clone()),
208            _ => None,
209        }
210    }
211
212    fn dynamic_return_shape(ty: Type) -> Option<Type> {
213        match ty {
214            Type::Map => Some(Type::Map),
215            Type::List(elem) => Some(Type::List(elem)),
216            Type::Array(elem, _) => Some(Type::List(elem)),
217            _ => None,
218        }
219    }
220
221    fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
222        match &expr.kind {
223            ExprKind::Var(idx) => Some(*idx),
224            ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
225            _ => None,
226        }
227    }
228
229    fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
230        match method {
231            "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
232                Some(ListElemState::Known(ty)) => ty,
233                Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
234                None => elem_ty.clone(),
235            })),
236            "push" => {
237                let pushed_ty = params
238                    .first()
239                    .map(|param| {
240                        if let Some(value) = self.get_value(param)
241                            && (value.is_str() || value.is_native())
242                        {
243                            Ok(value.get_type())
244                        } else {
245                            self.infer_expr(param)
246                        }
247                    })
248                    .transpose()?
249                    .unwrap_or(Type::Any);
250                if let Some(idx) = self.local_var_idx_for_expr(target) {
251                    let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
252                    let next_state = match state {
253                        ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
254                        ListElemState::Unknown => ListElemState::Known(pushed_ty),
255                        ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
256                        ListElemState::Known(prev) => {
257                            let merged = if prev == pushed_ty {
258                                prev
259                            } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
260                                prev + pushed_ty
261                            } else {
262                                Type::Any
263                            };
264                            if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
265                        }
266                        ListElemState::Mixed => ListElemState::Mixed,
267                    };
268                    let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
269                    self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
270                    self.set_list_elem_state(idx, Some(next_state));
271                }
272                Ok(Some(Type::Void))
273            }
274            "len" => Ok(Some(Type::I32)),
275            "is_list" | "is_null" => Ok(Some(Type::Bool)),
276            _ => Ok(None),
277        }
278    }
279
280    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
281        let ty = self.infer_expr(expr)?;
282        let shape = self.return_shape(expr, &ty);
283        let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
284        Ok(ReturnInfo { ty, shape })
285    }
286
287    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
288        let Some(left) = left else {
289            return Ok(right);
290        };
291        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
292            && left_shape != right_shape
293        {
294            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
295        }
296        if let Some(left_shape) = &left.shape
297            && left_shape.is_struct()
298            && right.ty.is_any()
299            && right.shape.is_none()
300        {
301            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
302        }
303        if let Some(right_shape) = &right.shape
304            && right_shape.is_struct()
305            && left.ty.is_any()
306            && left.shape.is_none()
307        {
308            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
309        }
310        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
311        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
312    }
313
314    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
315        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
316    }
317
318    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
319        self.infer_returns(stmt, true).map(|_| ())
320    }
321
322    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
323        match &stmt.kind {
324            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
325            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
326            StmtKind::Block(stmts) => {
327                let mut ret = None;
328                for (idx, stmt) in stmts.iter().enumerate() {
329                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
330                    if let Some(info) = info {
331                        self.update_pending_return_seed(&info.ty);
332                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
333                        if let Some(ret) = &ret {
334                            self.update_pending_return_seed(&ret.ty);
335                        }
336                    }
337                    if always_returns {
338                        return Ok((ret, true));
339                    }
340                }
341                Ok((ret, false))
342            }
343            StmtKind::If { cond, then_body, else_body } => {
344                let cond_ty = self.infer_expr(cond)?;
345                if cond_ty != Type::Bool {
346                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
347                }
348                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
349                if let Some(ret) = &ret {
350                    self.update_pending_return_seed(&ret.ty);
351                }
352                let else_returns = if let Some(body) = else_body {
353                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
354                    if let Some(info) = else_ty {
355                        self.update_pending_return_seed(&info.ty);
356                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
357                        if let Some(ret) = &ret {
358                            self.update_pending_return_seed(&ret.ty);
359                        }
360                    }
361                    else_returns
362                } else {
363                    false
364                };
365                Ok((ret, then_returns && else_returns))
366            }
367            StmtKind::While { cond, body } => {
368                let cond_ty = self.infer_expr(cond)?;
369                if cond_ty != Type::Bool {
370                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
371                }
372                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
373            }
374            StmtKind::Loop(body) => self.infer_returns(body, false),
375            StmtKind::For { pat, range, body } => {
376                let ty = self.for_pattern_ty(range)?;
377                self.add_pattern_bindings_for_infer(pat, ty)?;
378                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
379            }
380            StmtKind::Let { .. } => {
381                self.infer_stmt(stmt)?;
382                Ok((None, false))
383            }
384            StmtKind::Expr(expr, close) => {
385                let info = self.infer_return_expr(expr)?;
386                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
387            }
388            _ => {
389                self.infer_stmt(stmt)?;
390                Ok((None, false))
391            }
392        }
393    }
394
395    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
396        match &expr.kind {
397            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
398            ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
399            ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
400            ExprKind::Value(v) => Ok(v.get_type()),
401            ExprKind::Const(idx) => Ok(match self.consts.get(*idx) {
402                Some(value) if value.is_str() => Type::Str,
403                Some(value) if value.is_list() && value.len() == 0 => Type::list_any(),
404                _ => Type::Any,
405            }),
406            ExprKind::Var(idx) => {
407                let idx = self.top() + (*idx as usize);
408                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
409            }
410            ExprKind::Ident(ident) => {
411                for idx in (self.top()..self.names.len()).rev() {
412                    if self.names[idx].eq(ident) && idx < self.tys.len() {
413                        return self.symbols.get_type(&self.tys[idx]);
414                    }
415                }
416                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
417                match self.symbols.get_symbol(id)?.1 {
418                    Symbol::Const { ty, .. } => Ok(ty.clone()),
419                    Symbol::Static { ty, .. } => Ok(ty.clone()),
420                    Symbol::Struct(ty, _) => Ok(ty.clone()),
421                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
422                    Symbol::Native(ty) => Ok(ty.clone()),
423                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
424                }
425            }
426            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
427                Symbol::Const { ty, .. } => Ok(ty.clone()),
428                Symbol::Static { ty, .. } => Ok(ty.clone()),
429                Symbol::Struct(ty, _) => Ok(ty.clone()),
430                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
431                Symbol::Native(ty) => Ok(ty.clone()),
432                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
433            },
434            ExprKind::Generic { obj, params } => {
435                let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
436                match self.infer_expr(obj)? {
437                    Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
438                    _ => Ok(Type::Any),
439                }
440            }
441            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
442            ExprKind::Unary { op, value } => match op {
443                UnaryOp::Not => {
444                    let ty = self.infer_expr(value.as_ref())?;
445                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
446                }
447                UnaryOp::Neg => self.infer_expr(value.as_ref()),
448                UnaryOp::Unknow => Ok(Type::Any),
449            },
450            ExprKind::Binary { left, op, right } => {
451                if op == &BinaryOp::Assign
452                    && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
453                {
454                    if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
455                        if left_items.len() != right_items.len() {
456                            return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
457                        }
458                        for item in right_items {
459                            let _ = self.infer_expr(item)?;
460                        }
461                    } else {
462                        let _ = self.infer_expr(right)?;
463                    }
464                    return Ok(Type::Void);
465                }
466                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
467                let ty = if op.is_logic() {
468                    Type::Bool
469                } else if op == &BinaryOp::Idx {
470                    let left_ty = self.infer_expr(left)?;
471                    if let Type::Array(elem_ty, _) = left_ty {
472                        (*elem_ty).clone()
473                    } else if let Type::Vec(elem_ty, _) = left_ty {
474                        (*elem_ty).clone()
475                    } else if let Type::List(elem_ty) = left_ty {
476                        (*elem_ty).clone()
477                    } else {
478                        let left_ty = self.symbols.get_type(&left_ty)?;
479                        let right_ty = if right.is_value() || right.is_const() {
480                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
481                            if right_value.is_str() {
482                                if left_ty.is_any() {
483                                    return Ok(Type::Any);
484                                }
485                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
486                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
487                                }
488                            } else if let Type::Struct { fields, .. } = &left_ty
489                                && let Some(idx) = right_value.as_int()
490                            {
491                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
492                            }
493                            right_value.get_type()
494                        } else {
495                            self.infer_expr(right)?
496                        };
497                        if right_ty.is_int() || right_ty.is_uint() {
498                            if left_ty.is_any() {
499                                return Ok(Type::Any);
500                            }
501                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
502                            let fn_ty = self.symbols.get_type(&s)?;
503                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
504                        }
505                        if left_ty.is_any() {
506                            return Ok(Type::Any);
507                        }
508                        Type::Any
509                    }
510                } else {
511                    let left_ty = self.infer_expr(left)?;
512                    let right_ty = self.infer_expr(right)?;
513                    if op == &BinaryOp::Assign {
514                        if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
515                    } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
516                        left_ty
517                    } else {
518                        left_ty + right_ty
519                    }
520                };
521                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
522                Ok(ty)
523            }
524            ExprKind::Call { obj, params } => {
525                if let ExprKind::Assoc { ty, name } = &obj.kind {
526                    let base_name = match ty {
527                        Type::Ident { name, .. } => name.clone(),
528                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
529                        _ => return Ok(Type::Any),
530                    };
531                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
532                    let generic_args = match ty {
533                        Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
534                        _ => Vec::new(),
535                    };
536                    let mut args = Vec::new();
537                    for p in params {
538                        args.push(self.infer_expr(p)?);
539                    }
540                    self.infer_fn_with_params(id, &args, &generic_args)
541                } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
542                    let mut args = Vec::new();
543                    for p in params {
544                        args.push(self.infer_expr(p)?);
545                    }
546                    self.infer_fn_with_params(*id, &args, generic_args)
547                } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
548                    let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
549                        return Ok(Type::Any);
550                    };
551                    let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
552                    let mut args = Vec::new();
553                    for p in params {
554                        args.push(self.infer_expr(p)?);
555                    }
556                    self.infer_fn_with_params(id, &args, &generic_args)
557                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
558                    let base_name = match ty {
559                        Type::Ident { name, .. } => name.clone(),
560                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
561                        _ => return Ok(Type::Any),
562                    };
563                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
564                    let mut args = vec![self.infer_expr(target)?];
565                    for p in params {
566                        args.push(self.infer_expr(p)?);
567                    }
568                    self.infer_fn(id, &args)
569                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
570                    let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
571                    if let Some(target) = obj_expr
572                        && let Some(method) = method
573                    {
574                        let target_ty = self.infer_expr(target)?;
575                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
576                            && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
577                        {
578                            return Ok(ret_ty);
579                        }
580                    }
581                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
582                    for p in params {
583                        args.push(self.infer_expr(p)?);
584                    }
585                    self.infer_fn(*id, &args)
586                } else if let ExprKind::Ident(name) = &obj.kind {
587                    for idx in (self.top()..self.names.len()).rev() {
588                        if self.names[idx].eq(name) && idx < self.tys.len() {
589                            return if let Type::Symbol { id, .. } = &self.tys[idx] {
590                                let id = *id;
591                                let mut args = Vec::new();
592                                for p in params {
593                                    args.push(self.infer_expr(p)?);
594                                }
595                                self.infer_fn(id, &args)
596                            } else {
597                                Ok(Type::Any)
598                            };
599                        }
600                    }
601                    let Ok(id) = self.symbols.get_id(name) else {
602                        return Ok(Type::Any);
603                    };
604                    if !self.symbols.get_symbol(id)?.1.is_fn() {
605                        return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
606                    }
607                    let mut args = Vec::new();
608                    for p in params {
609                        args.push(self.infer_expr(p)?);
610                    }
611                    self.infer_fn(id, &args)
612                } else if obj.is_idx() {
613                    let (target, _, method) = obj.clone().binary().unwrap();
614                    let ty = self.infer_expr(&target)?;
615                    if let Some(method) = self.get_value(&method) {
616                        let method = method.as_str();
617                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
618                            && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
619                        {
620                            return Ok(ret_ty);
621                        }
622                        let fn_ty = match self.get_field(&ty, method) {
623                            Ok((_, fn_ty)) => fn_ty,
624                            Err(_) => {
625                                let id = self.symbols.get_id(method)?;
626                                if self.symbols.get_symbol(id)?.1.is_fn() {
627                                    Type::Symbol { id, params: Vec::new() }
628                                } else {
629                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
630                                }
631                            }
632                        };
633                        if let Type::Symbol { id, .. } = fn_ty {
634                            let mut args = vec![ty];
635                            for p in params {
636                                args.push(self.infer_expr(p)?);
637                            }
638                            self.infer_fn(id, &args)
639                        } else {
640                            Ok(fn_ty)
641                        }
642                    } else {
643                        Ok(Type::Any)
644                    }
645                } else if let ExprKind::Var(idx) = &obj.kind {
646                    let idx = self.top() + (*idx as usize);
647                    if idx < self.tys.len()
648                        && let Type::Symbol { id, .. } = self.tys[idx]
649                    {
650                        let mut args = Vec::new();
651                        for p in params {
652                            args.push(self.infer_expr(p)?);
653                        }
654                        self.infer_fn(id, &args)
655                    } else {
656                        Ok(Type::Any)
657                    }
658                } else if obj.is_value() {
659                    Ok(Type::Void)
660                } else {
661                    Ok(Type::Any)
662                }
663            }
664            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
665            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
666            ExprKind::Repeat { value, len } => {
667                let value_ty = self.infer_expr(value)?;
668                let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
669                if let Type::ConstInt(len) = len {
670                    let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
671                    Ok(Type::Array(std::rc::Rc::new(value_ty), len))
672                } else {
673                    Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
674                }
675            }
676            ExprKind::List(items) => {
677                if items.is_empty() {
678                    return Ok(Type::list_any());
679                }
680                let mut elem_ty = Type::Any;
681                for item in items {
682                    let item_ty = self.infer_expr(item)?;
683                    elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
684                }
685                Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
686            }
687            ExprKind::Range { start, stop, .. } => {
688                let start_ty = self.infer_expr(start)?;
689                let stop_ty = self.infer_expr(stop)?;
690                Ok(Self::merge_range_bound_types(start_ty, stop_ty))
691            }
692            _ => Ok(Type::Any),
693        }
694    }
695
696    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
697        let mut fn_tys = Vec::new();
698        for (i, ty) in tys.iter().enumerate() {
699            if !ty.is_any() {
700                fn_tys.push(ty.clone());
701            } else if let Some(arg_ty) = arg_tys.get(i) {
702                fn_tys.push(self.symbols.get_type(arg_ty)?);
703            } else {
704                fn_tys.push(Type::Any);
705            }
706        }
707        Ok(fn_tys)
708    }
709
710    fn is_optimizable_local_ty(ty: &Type) -> bool {
711        ty.is_bool() || ty.is_native()
712    }
713
714    fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
715        matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
716    }
717
718    fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
719        let ty = self.tys.get(pos)?;
720        match ty {
721            Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
722                if let ListElemState::Known(elem_ty) = state
723                    && Self::is_optimizable_list_elem_ty(&elem_ty)
724                {
725                    Some(Type::List(std::rc::Rc::new(elem_ty)))
726                } else {
727                    None
728                }
729            }),
730            ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
731            _ => None,
732        }
733    }
734
735    fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
736        (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
737    }
738
739    fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
740        let items = self.local_type_hints.entry(id).or_default();
741        if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
742            item.2 = hints;
743        } else {
744            items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
745        }
746    }
747
748    pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
749        self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
750    }
751
752    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
753        self.infer_fn_with_params(id, arg_tys, &[])
754    }
755
756    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
757        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
758        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
759            if let Type::Fn { tys, ret: _ } = ty {
760                let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
761                let generic_args = resolved_generic_args.as_slice();
762                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
763                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
764                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
765                let body = if generic_params.is_empty() {
766                    body
767                } else {
768                    let mut compile_tys = tys.clone();
769                    let mut compile_cap = cap.clone();
770                    let saved_state = self.take_local_state();
771                    if let Some((module, _)) = name.split_once("::") {
772                        self.symbols.push_module_scope(module.into());
773                    }
774                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
775                    if name.contains("::") {
776                        self.symbols.pop_module_scope();
777                    }
778                    self.restore_local_state(saved_state);
779                    Stmt::new(StmtKind::Block(compiled?), Span::default())
780                };
781                if let Some(fns) = self.fns.get_mut(&id) {
782                    for f in fns.iter() {
783                        if f.0 == generic_args && f.1 == fn_tys {
784                            return match &f.2 {
785                                FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
786                                FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or_else(|| {
787                                    // 递归自调用且种子为空:尝试从函数体 base case 查找返回类型
788                                    if self.infer_stack.iter().any(|(sid, sargs, _)| *sid == id && sargs == generic_args) {
789                                        if let Some(base_ty) = self.try_find_base_return_ty(&body) {
790                                            return self.symbols.get_type(&base_ty);
791                                        }
792                                    }
793                                    Ok(Type::Any)
794                                }),
795                            };
796                        }
797                    }
798                    fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
799                } else {
800                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
801                }
802                // 递归函数:预扫描 base case 返回类型作为种子
803                if self.pending_return_seed(id, generic_args, &fn_tys).is_none() {
804                    if let Some(base_ty) = self.try_find_base_return_ty_with_scope(&body, id, &name, &args, &fn_tys) {
805                        if let Some(fns) = self.fns.get_mut(&id) {
806                            if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
807                                && let FnInferRet::Pending(seed) = &mut item.2
808                                && seed.is_none()
809                            {
810                                *seed = Some(base_ty);
811                            }
812                        }
813                    }
814                }
815                let mut ret_ty = None;
816                let mut local_type_hints = Vec::new();
817                for _ in 0..4 {
818                    let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
819                    let saved_state = self.take_local_state();
820                    self.frames.push(0);
821                    for (arg, ty) in args.iter().zip(fn_tys.iter()) {
822                        self.add_name(arg.clone());
823                        self.add_ty(ty.clone());
824                    }
825                    for c in cap.vars.iter() {
826                        if let Some((name, ty)) = cap.names.get(*c) {
827                            self.add_name(name.clone());
828                            self.add_ty(ty.clone());
829                        } else {
830                            self.add_name("".into());
831                            self.add_ty(Type::Any);
832                        }
833                    }
834                    self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
835                    let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
836                    self.infer_stack.pop();
837                    let pass_local_type_hints = self.collect_local_type_hints();
838                    self.restore_local_state(saved_state);
839                    let pass_ret_ty = match pass_ret_ty {
840                        Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
841                        Err(err) => {
842                            log::error!("infer_fn {} failed: {:?}", name, err);
843                            let should_remove = self
844                                .fns
845                                .get_mut(&id)
846                                .map(|fns| {
847                                    fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
848                                    fns.is_empty()
849                                })
850                                .unwrap_or(false);
851                            if should_remove {
852                                self.fns.remove(&id);
853                            }
854                            return Err(err);
855                        }
856                    };
857                    if !pass_ret_ty.is_any() {
858                        self.update_pending_return_seed(&pass_ret_ty);
859                        ret_ty = Some(pass_ret_ty.clone());
860                    } else if ret_ty.is_none() {
861                        ret_ty = Some(pass_ret_ty);
862                    }
863                    local_type_hints = pass_local_type_hints;
864                    let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
865                    if before_seed == after_seed {
866                        break;
867                    }
868                }
869                let ret_ty = ret_ty.unwrap_or(Type::Any);
870                self.fns.get_mut(&id).map(|f| {
871                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
872                });
873                self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
874                if generic_args.is_empty()
875                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
876                    && ret.is_any()
877                {
878                    *ret = std::rc::Rc::new(ret_ty.clone());
879                }
880                Ok(ret_ty)
881            } else {
882                Ok(Type::Any)
883            }
884        } else if let Symbol::Native(f) = s {
885            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
886        } else if matches!(s, Symbol::Null) {
887            Ok(Type::Any)
888        } else {
889            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
890        }
891    }
892
893    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
894        match &stmt.kind {
895            StmtKind::Expr(expr, close) => {
896                if !close {
897                    self.infer_expr(expr)
898                } else {
899                    self.infer_expr(expr)?;
900                    Ok(Type::Void)
901                }
902            }
903            StmtKind::Return(expr) => {
904                if let Some(e) = expr {
905                    self.infer_expr(e)
906                } else {
907                    Ok(Type::Void)
908                }
909            }
910            StmtKind::Block(stmts) => {
911                for (idx, stmt) in stmts.iter().enumerate() {
912                    let ty = self.infer_stmt(stmt)?;
913                    if stmt.is_return() || idx == stmts.len() - 1 {
914                        return Ok(ty);
915                    }
916                }
917                Ok(Type::Void)
918            }
919            StmtKind::If { then_body, else_body, .. } => {
920                let then_ty = self.infer_stmt(then_body)?;
921                if let Some(e) = else_body {
922                    let else_ty = self.infer_stmt(e)?;
923                    if then_ty != else_ty {
924                        log::debug!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
925                        return Self::merge_return_type(stmt.span, Some(then_ty), else_ty);
926                    }
927                }
928                if else_body.is_none() {
929                    return Ok(Type::Void);
930                }
931                Ok(then_ty)
932            }
933            StmtKind::While { cond, body } => {
934                let cond_ty = self.infer_expr(cond)?;
935                if cond_ty != Type::Bool {
936                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
937                }
938                self.infer_stmt(body)
939            }
940            StmtKind::For { pat, range, body } => {
941                let ty = self.for_pattern_ty(range)?;
942                self.add_pattern_bindings_for_infer(pat, ty)?;
943                self.infer_stmt(body)
944            }
945            StmtKind::Let { pat, value } => {
946                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
947                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
948                Ok(Type::Void)
949            }
950            _ => Ok(Type::Void),
951        }
952    }
953}