Skip to main content

compiler/
infer.rs

1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8    ty: Type,
9    shape: Option<Type>,
10}
11
12impl Compiler {
13    fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
14        self.infer_stack.last().cloned()
15    }
16
17    fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
18        self.fns.get(&id).and_then(|fns| {
19            fns.iter().find_map(|item| {
20                if item.0 == generic_args
21                    && item.1 == fn_tys
22                    && let FnInferRet::Pending(seed) = &item.2
23                {
24                    seed.clone()
25                } else {
26                    None
27                }
28            })
29        })
30    }
31
32    fn update_pending_return_seed(&mut self, ty: &Type) {
33        if ty.is_any() {
34            return;
35        }
36        let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
37            return;
38        };
39        let Some(fns) = self.fns.get_mut(&id) else {
40            return;
41        };
42        if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
43            && let FnInferRet::Pending(seed) = &mut item.2
44        {
45            let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
46            *seed = Some(next);
47        }
48    }
49
50    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
51        match &pat.kind {
52            PatternKind::Ident { name, ty } => {
53                let annotated_ty = self.symbols.get_type(ty)?;
54                self.add_name(name.clone());
55                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
56            }
57            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
58            PatternKind::Tuple(pats) => {
59                if let Type::Tuple(tys) = expr_ty {
60                    for (pat, ty) in pats.iter().zip(tys) {
61                        self.add_pattern_bindings_for_infer(pat, ty)?;
62                    }
63                } else {
64                    for pat in pats {
65                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
66                    }
67                }
68            }
69            PatternKind::List { elems, .. } => {
70                for pat in elems {
71                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
72                }
73            }
74            PatternKind::Wildcard => {
75                self.add_name("".into());
76                self.add_ty(expr_ty);
77            }
78            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
79        }
80        Ok(())
81    }
82
83    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
84        if matches!(range.kind, ExprKind::Range { .. }) {
85            return self.infer_range_expr(range);
86        }
87        Ok(match self.infer_expr(range)? {
88            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
89            _ => Type::Any,
90        })
91    }
92
93    fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
94        let ExprKind::Range { start, stop, .. } = &range.kind else {
95            return self.infer_expr(range);
96        };
97        let start_ty = self.infer_expr(start)?;
98        let stop_ty = self.infer_expr(stop)?;
99        Ok(Self::merge_range_bound_types(start_ty, stop_ty))
100    }
101
102    fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
103        if start_ty.is_any() {
104            stop_ty
105        } else if stop_ty.is_any() {
106            start_ty
107        } else if start_ty == Type::I32 && stop_ty.is_uint() {
108            stop_ty
109        } else if stop_ty == Type::I32 && start_ty.is_uint() {
110            start_ty
111        } else {
112            start_ty + stop_ty
113        }
114    }
115
116    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
117        match left {
118            Some(left) if left == right => Ok(left),
119            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
120            Some(left) if left.is_any() || right.is_any() => Ok(Type::Any),
121            Some(left) => Ok(left + right),
122            None => Ok(right),
123        }
124    }
125
126    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
127        if !ty.is_any() {
128            return match ty {
129                Type::Struct { .. } => Some(ty.clone()),
130                Type::Map => Some(Type::Map),
131                Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
132                _ => None,
133            };
134        }
135        match &expr.kind {
136            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
137            ExprKind::Dict(_) => Some(Type::Map),
138            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
139            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
140            ExprKind::Typed { ty, .. } => Some(ty.clone()),
141            _ => None,
142        }
143    }
144
145    fn dynamic_return_shape(ty: Type) -> Option<Type> {
146        match ty {
147            Type::Map => Some(Type::Map),
148            Type::List(elem) => Some(Type::List(elem)),
149            Type::Array(elem, _) => Some(Type::List(elem)),
150            _ => None,
151        }
152    }
153
154    fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
155        match &expr.kind {
156            ExprKind::Var(idx) => Some(*idx),
157            ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
158            _ => None,
159        }
160    }
161
162    fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
163        match method {
164            "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
165                Some(ListElemState::Known(ty)) => ty,
166                Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
167                None => elem_ty.clone(),
168            })),
169            "push" => {
170                let pushed_ty = params
171                    .first()
172                    .map(|param| {
173                        if let Some(value) = self.get_value(param)
174                            && (value.is_str() || value.is_native())
175                        {
176                            Ok(value.get_type())
177                        } else {
178                            self.infer_expr(param)
179                        }
180                    })
181                    .transpose()?
182                    .unwrap_or(Type::Any);
183                if let Some(idx) = self.local_var_idx_for_expr(target) {
184                    let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
185                    let next_state = match state {
186                        ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
187                        ListElemState::Unknown => ListElemState::Known(pushed_ty),
188                        ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
189                        ListElemState::Known(prev) => {
190                            let merged = if prev == pushed_ty {
191                                prev
192                            } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
193                                prev + pushed_ty
194                            } else {
195                                Type::Any
196                            };
197                            if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
198                        }
199                        ListElemState::Mixed => ListElemState::Mixed,
200                    };
201                    let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
202                    self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
203                    self.set_list_elem_state(idx, Some(next_state));
204                }
205                Ok(Some(Type::Void))
206            }
207            "len" => Ok(Some(Type::I32)),
208            "is_list" | "is_null" => Ok(Some(Type::Bool)),
209            _ => Ok(None),
210        }
211    }
212
213    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
214        let ty = self.infer_expr(expr)?;
215        let shape = self.return_shape(expr, &ty);
216        let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
217        Ok(ReturnInfo { ty, shape })
218    }
219
220    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
221        let Some(left) = left else {
222            return Ok(right);
223        };
224        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
225            && left_shape != right_shape
226        {
227            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
228        }
229        if let Some(left_shape) = &left.shape
230            && left_shape.is_struct()
231            && right.ty.is_any()
232            && right.shape.is_none()
233        {
234            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
235        }
236        if let Some(right_shape) = &right.shape
237            && right_shape.is_struct()
238            && left.ty.is_any()
239            && left.shape.is_none()
240        {
241            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
242        }
243        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
244        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
245    }
246
247    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
248        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
249    }
250
251    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
252        self.infer_returns(stmt, true).map(|_| ())
253    }
254
255    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
256        match &stmt.kind {
257            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
258            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
259            StmtKind::Block(stmts) => {
260                let mut ret = None;
261                for (idx, stmt) in stmts.iter().enumerate() {
262                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
263                    if let Some(info) = info {
264                        self.update_pending_return_seed(&info.ty);
265                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
266                        if let Some(ret) = &ret {
267                            self.update_pending_return_seed(&ret.ty);
268                        }
269                    }
270                    if always_returns {
271                        return Ok((ret, true));
272                    }
273                }
274                Ok((ret, false))
275            }
276            StmtKind::If { cond, then_body, else_body } => {
277                let cond_ty = self.infer_expr(cond)?;
278                if cond_ty != Type::Bool {
279                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
280                }
281                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
282                if let Some(ret) = &ret {
283                    self.update_pending_return_seed(&ret.ty);
284                }
285                let else_returns = if let Some(body) = else_body {
286                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
287                    if let Some(info) = else_ty {
288                        self.update_pending_return_seed(&info.ty);
289                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
290                        if let Some(ret) = &ret {
291                            self.update_pending_return_seed(&ret.ty);
292                        }
293                    }
294                    else_returns
295                } else {
296                    false
297                };
298                Ok((ret, then_returns && else_returns))
299            }
300            StmtKind::While { cond, body } => {
301                let cond_ty = self.infer_expr(cond)?;
302                if cond_ty != Type::Bool {
303                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
304                }
305                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
306            }
307            StmtKind::Loop(body) => self.infer_returns(body, false),
308            StmtKind::For { pat, range, body } => {
309                let ty = self.for_pattern_ty(range)?;
310                self.add_pattern_bindings_for_infer(pat, ty)?;
311                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
312            }
313            StmtKind::Let { .. } => {
314                self.infer_stmt(stmt)?;
315                Ok((None, false))
316            }
317            StmtKind::Expr(expr, close) => {
318                let info = self.infer_return_expr(expr)?;
319                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
320            }
321            _ => {
322                self.infer_stmt(stmt)?;
323                Ok((None, false))
324            }
325        }
326    }
327
/// Infers the static type of `expr`.
///
/// Anything whose type cannot be determined is reported as `Type::Any`.
/// Inference has side effects. An assignment whose left-hand side is a `Var`
/// slot records the resulting type with `set_ty`. List method calls are routed
/// through `infer_list_method`, which may update the tracked element state of
/// the receiving local.
///
/// # Errors
/// Returns a semantic error for:
/// * an unknown identifier;
/// * a symbol used as a value that is not a const/static/struct/fn/native;
/// * mismatched arity in a multiple assignment;
/// * an out-of-range positional struct field index;
/// * a call through a symbol that is not a function;
/// * a repeat length that does not fit in `u32`.
///
/// Failures from symbol and type lookups are propagated.
pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
    match &expr.kind {
        // `null` and map values carry no useful static type.
        ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
        ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
        ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
        ExprKind::Value(v) => Ok(v.get_type()),
        // Only string constants and empty-list constants get a concrete type here.
        ExprKind::Const(idx) => Ok(match self.consts.get(*idx) {
            Some(value) if value.is_str() => Type::Str,
            Some(value) if value.is_list() && value.len() == 0 => Type::list_any(),
            _ => Type::Any,
        }),
        // `Var` indices are relative to the current frame's `top()`.
        ExprKind::Var(idx) => {
            let idx = self.top() + (*idx as usize);
            if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
        }
        ExprKind::Ident(ident) => {
            // The innermost local binding in the current frame wins, so scan backwards.
            for idx in (self.top()..self.names.len()).rev() {
                if self.names[idx].eq(ident) && idx < self.tys.len() {
                    return self.symbols.get_type(&self.tys[idx]);
                }
            }
            // Not a local: resolve it as a global symbol.
            let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
            match self.symbols.get_symbol(id)?.1 {
                Symbol::Const { ty, .. } => Ok(ty.clone()),
                Symbol::Static { ty, .. } => Ok(ty.clone()),
                Symbol::Struct(ty, _) => Ok(ty.clone()),
                // Functions are typed as a reference to their symbol.
                Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
                Symbol::Native(ty) => Ok(ty.clone()),
                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
            }
        }
        // Already-resolved symbol id: same classification as the global `Ident` case.
        ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
            Symbol::Const { ty, .. } => Ok(ty.clone()),
            Symbol::Static { ty, .. } => Ok(ty.clone()),
            Symbol::Struct(ty, _) => Ok(ty.clone()),
            Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
            Symbol::Native(ty) => Ok(ty.clone()),
            s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
        },
        // `obj::<T, ...>`: attach the (resolved where possible) generic arguments to the symbol.
        ExprKind::Generic { obj, params } => {
            let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
            match self.infer_expr(obj)? {
                Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
                _ => Ok(Type::Any),
            }
        }
        ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
        ExprKind::Unary { op, value } => match op {
            // `!` on an integer keeps the integer type (bitwise not); otherwise it is logical.
            UnaryOp::Not => {
                let ty = self.infer_expr(value.as_ref())?;
                if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
            }
            UnaryOp::Neg => self.infer_expr(value.as_ref()),
            UnaryOp::Unknow => Ok(Type::Any),
        },
        ExprKind::Binary { left, op, right } => {
            // Multiple assignment `(a, b) = ...` / `[a, b] = ...`: check arity only when the
            // right side is also a literal tuple/list, infer the right side (results
            // discarded, errors propagated) and type the whole expression as `Void`.
            if op == &BinaryOp::Assign
                && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
            {
                if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
                    if left_items.len() != right_items.len() {
                        return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
                    }
                    for item in right_items {
                        let _ = self.infer_expr(item)?;
                    }
                } else {
                    let _ = self.infer_expr(right)?;
                }
                return Ok(Type::Void);
            }
            // For an assignment directly into a local slot, remember the slot so the
            // resulting type can be recorded with `set_ty` at the end.
            let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
            let ty = if op.is_logic() {
                Type::Bool
            } else if op == &BinaryOp::Idx {
                // Indexing a sequence type yields its element type.
                let left_ty = self.infer_expr(left)?;
                if let Type::Array(elem_ty, _) = left_ty {
                    (*elem_ty).clone()
                } else if let Type::Vec(elem_ty, _) = left_ty {
                    (*elem_ty).clone()
                } else if let Type::List(elem_ty) = left_ty {
                    (*elem_ty).clone()
                } else {
                    let left_ty = self.symbols.get_type(&left_ty)?;
                    let right_ty = if right.is_value() || right.is_const() {
                        // Constant index: a string key is a field/method lookup; an integer
                        // key on a struct selects a field by position.
                        let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
                        if right_value.is_str() {
                            if left_ty.is_any() {
                                return Ok(Type::Any);
                            }
                            // A function-typed field yields the function's return type.
                            if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
                                return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
                            }
                            // Unresolved string keys fall through and end up as `Any` below.
                        } else if let Type::Struct { fields, .. } = &left_ty
                            && let Some(idx) = right_value.as_int()
                        {
                            return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
                        }
                        right_value.get_type()
                    } else {
                        self.infer_expr(right)?
                    };
                    // Any other integer index goes through the receiver's `get_idx` method.
                    if right_ty.is_int() || right_ty.is_uint() {
                        if left_ty.is_any() {
                            return Ok(Type::Any);
                        }
                        let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
                        let fn_ty = self.symbols.get_type(&s)?;
                        return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
                    }
                    if left_ty.is_any() {
                        return Ok(Type::Any);
                    }
                    Type::Any
                }
            } else {
                let left_ty = self.infer_expr(left)?;
                let right_ty = self.infer_expr(right)?;
                // Assigning an untyped value keeps the target's known type;
                // other operators combine both sides with `Type`'s `+`.
                if op == &BinaryOp::Assign {
                    if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
                } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
                    left_ty
                } else {
                    left_ty + right_ty
                }
            };
            // NOTE(review): `Option::map` used only for its side effect; `if let` would read clearer.
            assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
            Ok(ty)
        }
        ExprKind::Call { obj, params } => {
            if let ExprKind::Assoc { ty, name } = &obj.kind {
                // `Type::name(...)`: resolve the `"{base}::{name}"` symbol and infer it with the type's generic arguments.
                let base_name = match ty {
                    Type::Ident { name, .. } => name.clone(),
                    Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
                    _ => return Ok(Type::Any),
                };
                let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
                let generic_args = match ty {
                    Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
                    _ => Vec::new(),
                };
                let mut args = Vec::new();
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn_with_params(id, &args, &generic_args)
            } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
                // Associated function already resolved to an id.
                let mut args = Vec::new();
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn_with_params(*id, &args, generic_args)
            } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
                // `f::<T, ...>(...)`: only function symbols can be instantiated.
                let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
                    return Ok(Type::Any);
                };
                let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
                let mut args = Vec::new();
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn_with_params(id, &args, &generic_args)
            } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
                // Method call with an explicit receiver type: the receiver becomes the first argument.
                let base_name = match ty {
                    Type::Ident { name, .. } => name.clone(),
                    Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
                    _ => return Ok(Type::Any),
                };
                let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
                let mut args = vec![self.infer_expr(target)?];
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn(id, &args)
            } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
                // Resolved function/method symbol. For a list/array receiver, built-in
                // list methods (by the last `::` segment of the symbol name) take priority.
                // NOTE(review): when that check does not match, the receiver expression is
                // inferred a second time below as the first argument.
                let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
                if let Some(target) = obj_expr
                    && let Some(method) = method
                {
                    let target_ty = self.infer_expr(target)?;
                    if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
                        && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
                    {
                        return Ok(ret_ty);
                    }
                }
                let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn(*id, &args)
            } else if let ExprKind::Ident(name) = &obj.kind {
                // A local shadows globals: call it only if it holds a function symbol.
                for idx in (self.top()..self.names.len()).rev() {
                    if self.names[idx].eq(name) && idx < self.tys.len() {
                        return if let Type::Symbol { id, .. } = &self.tys[idx] {
                            let id = *id;
                            let mut args = Vec::new();
                            for p in params {
                                args.push(self.infer_expr(p)?);
                            }
                            self.infer_fn(id, &args)
                        } else {
                            Ok(Type::Any)
                        };
                    }
                }
                // Unknown global names are tolerated as `Any`; known non-functions are an error.
                let Ok(id) = self.symbols.get_id(name) else {
                    return Ok(Type::Any);
                };
                if !self.symbols.get_symbol(id)?.1.is_fn() {
                    return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
                }
                let mut args = Vec::new();
                for p in params {
                    args.push(self.infer_expr(p)?);
                }
                self.infer_fn(id, &args)
            } else if obj.is_idx() {
                // `target[method](...)` with a constant method name. `is_idx()` was just
                // checked; presumably that guarantees `binary()` succeeds.
                let (target, _, method) = obj.clone().binary().unwrap();
                let ty = self.infer_expr(&target)?;
                if let Some(method) = self.get_value(&method) {
                    let method = method.as_str();
                    if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
                        && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
                    {
                        return Ok(ret_ty);
                    }
                    // Prefer a field of the receiver type; fall back to a global function named `method`.
                    let fn_ty = match self.get_field(&ty, method) {
                        Ok((_, fn_ty)) => fn_ty,
                        Err(_) => {
                            let id = self.symbols.get_id(method)?;
                            if self.symbols.get_symbol(id)?.1.is_fn() {
                                Type::Symbol { id, params: Vec::new() }
                            } else {
                                return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
                            }
                        }
                    };
                    if let Type::Symbol { id, .. } = fn_ty {
                        let mut args = vec![ty];
                        for p in params {
                            args.push(self.infer_expr(p)?);
                        }
                        self.infer_fn(id, &args)
                    } else {
                        Ok(fn_ty)
                    }
                } else {
                    Ok(Type::Any)
                }
            } else if let ExprKind::Var(idx) = &obj.kind {
                // Local slot holding a function symbol.
                let idx = self.top() + (*idx as usize);
                if idx < self.tys.len()
                    && let Type::Symbol { id, .. } = self.tys[idx]
                {
                    let mut args = Vec::new();
                    for p in params {
                        args.push(self.infer_expr(p)?);
                    }
                    self.infer_fn(id, &args)
                } else {
                    Ok(Type::Any)
                }
            } else if obj.is_value() {
                // Calling a literal value is typed as `Void`.
                Ok(Type::Void)
            } else {
                Ok(Type::Any)
            }
        }
        ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
        ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
        // `[value; len]`: a constant length gives a fixed-size `Array`, otherwise an `ArrayParam`.
        ExprKind::Repeat { value, len } => {
            let value_ty = self.infer_expr(value)?;
            let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
            if let Type::ConstInt(len) = len {
                let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
                Ok(Type::Array(std::rc::Rc::new(value_ty), len))
            } else {
                Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
            }
        }
        // A non-empty list literal is typed as a fixed-size `Array` of its merged item types.
        ExprKind::List(items) => {
            if items.is_empty() {
                return Ok(Type::list_any());
            }
            let mut elem_ty = Type::Any;
            for item in items {
                let item_ty = self.infer_expr(item)?;
                elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
            }
            Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
        }
        ExprKind::Range { start, stop, .. } => {
            let start_ty = self.infer_expr(start)?;
            let stop_ty = self.infer_expr(stop)?;
            Ok(Self::merge_range_bound_types(start_ty, stop_ty))
        }
        _ => Ok(Type::Any),
    }
}
628
629    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
630        let mut fn_tys = Vec::new();
631        for (i, ty) in tys.iter().enumerate() {
632            if !ty.is_any() {
633                fn_tys.push(ty.clone());
634            } else if let Some(arg_ty) = arg_tys.get(i) {
635                fn_tys.push(self.symbols.get_type(arg_ty)?);
636            } else {
637                fn_tys.push(Type::Any);
638            }
639        }
640        Ok(fn_tys)
641    }
642
643    fn is_optimizable_local_ty(ty: &Type) -> bool {
644        ty.is_bool() || ty.is_native()
645    }
646
647    fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
648        matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
649    }
650
651    fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
652        let ty = self.tys.get(pos)?;
653        match ty {
654            Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
655                if let ListElemState::Known(elem_ty) = state
656                    && Self::is_optimizable_list_elem_ty(&elem_ty)
657                {
658                    Some(Type::List(std::rc::Rc::new(elem_ty)))
659                } else {
660                    None
661                }
662            }),
663            ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
664            _ => None,
665        }
666    }
667
668    fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
669        (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
670    }
671
672    fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
673        let items = self.local_type_hints.entry(id).or_default();
674        if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
675            item.2 = hints;
676        } else {
677            items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
678        }
679    }
680
681    pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
682        self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
683    }
684
685    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
686        self.infer_fn_with_params(id, arg_tys, &[])
687    }
688
689    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
690        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
691        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
692            if let Type::Fn { tys, ret: _ } = ty {
693                let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
694                let generic_args = resolved_generic_args.as_slice();
695                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
696                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
697                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
698                let body = if generic_params.is_empty() {
699                    body
700                } else {
701                    let mut compile_tys = tys.clone();
702                    let mut compile_cap = cap.clone();
703                    let saved_state = self.take_local_state();
704                    if let Some((module, _)) = name.split_once("::") {
705                        self.symbols.push_module_scope(module.into());
706                    }
707                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
708                    if name.contains("::") {
709                        self.symbols.pop_module_scope();
710                    }
711                    self.restore_local_state(saved_state);
712                    Stmt::new(StmtKind::Block(compiled?), Span::default())
713                };
714                if let Some(fns) = self.fns.get_mut(&id) {
715                    for f in fns.iter() {
716                        if f.0 == generic_args && f.1 == fn_tys {
717                            return match &f.2 {
718                                FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
719                                FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or(Ok(Type::Any)),
720                            };
721                        }
722                    }
723                    fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
724                } else {
725                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
726                }
727                let mut ret_ty = None;
728                let mut local_type_hints = Vec::new();
729                for _ in 0..4 {
730                    let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
731                    let saved_state = self.take_local_state();
732                    self.frames.push(0);
733                    for (arg, ty) in args.iter().zip(fn_tys.iter()) {
734                        self.add_name(arg.clone());
735                        self.add_ty(ty.clone());
736                    }
737                    for c in cap.vars.iter() {
738                        if let Some((name, ty)) = cap.names.get(*c) {
739                            self.add_name(name.clone());
740                            self.add_ty(ty.clone());
741                        } else {
742                            self.add_name("".into());
743                            self.add_ty(Type::Any);
744                        }
745                    }
746                    self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
747                    let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
748                    self.infer_stack.pop();
749                    let pass_local_type_hints = self.collect_local_type_hints();
750                    self.restore_local_state(saved_state);
751                    let pass_ret_ty = match pass_ret_ty {
752                        Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
753                        Err(err) => {
754                            log::error!("infer_fn {} failed: {:?}", name, err);
755                            let should_remove = self
756                                .fns
757                                .get_mut(&id)
758                                .map(|fns| {
759                                    fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
760                                    fns.is_empty()
761                                })
762                                .unwrap_or(false);
763                            if should_remove {
764                                self.fns.remove(&id);
765                            }
766                            return Err(err);
767                        }
768                    };
769                    if !pass_ret_ty.is_any() {
770                        self.update_pending_return_seed(&pass_ret_ty);
771                        ret_ty = Some(pass_ret_ty.clone());
772                    } else if ret_ty.is_none() {
773                        ret_ty = Some(pass_ret_ty);
774                    }
775                    local_type_hints = pass_local_type_hints;
776                    let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
777                    if before_seed == after_seed {
778                        break;
779                    }
780                }
781                let ret_ty = ret_ty.unwrap_or(Type::Any);
782                self.fns.get_mut(&id).map(|f| {
783                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
784                });
785                self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
786                if generic_args.is_empty()
787                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
788                    && ret.is_any()
789                {
790                    *ret = std::rc::Rc::new(ret_ty.clone());
791                }
792                Ok(ret_ty)
793            } else {
794                Ok(Type::Any)
795            }
796        } else if let Symbol::Native(f) = s {
797            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
798        } else if matches!(s, Symbol::Null) {
799            Ok(Type::Any)
800        } else {
801            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
802        }
803    }
804
805    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
806        match &stmt.kind {
807            StmtKind::Expr(expr, close) => {
808                if !close {
809                    self.infer_expr(expr)
810                } else {
811                    self.infer_expr(expr)?;
812                    Ok(Type::Void)
813                }
814            }
815            StmtKind::Return(expr) => {
816                if let Some(e) = expr {
817                    self.infer_expr(e)
818                } else {
819                    Ok(Type::Void)
820                }
821            }
822            StmtKind::Block(stmts) => {
823                for (idx, stmt) in stmts.iter().enumerate() {
824                    let ty = self.infer_stmt(stmt)?;
825                    if stmt.is_return() || idx == stmts.len() - 1 {
826                        return Ok(ty);
827                    }
828                }
829                Ok(Type::Void)
830            }
831            StmtKind::If { then_body, else_body, .. } => {
832                let then_ty = self.infer_stmt(then_body)?;
833                if let Some(e) = else_body {
834                    let else_ty = self.infer_stmt(e)?;
835                    if then_ty != else_ty {
836                        log::debug!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
837                        return Self::merge_return_type(stmt.span, Some(then_ty), else_ty);
838                    }
839                }
840                if else_body.is_none() {
841                    return Ok(Type::Void);
842                }
843                Ok(then_ty)
844            }
845            StmtKind::While { cond, body } => {
846                let cond_ty = self.infer_expr(cond)?;
847                if cond_ty != Type::Bool {
848                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
849                }
850                self.infer_stmt(body)
851            }
852            StmtKind::For { pat, range, body } => {
853                let ty = self.for_pattern_ty(range)?;
854                self.add_pattern_bindings_for_infer(pat, ty)?;
855                self.infer_stmt(body)
856            }
857            StmtKind::Let { pat, value } => {
858                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
859                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
860                Ok(Type::Void)
861            }
862            _ => Ok(Type::Void),
863        }
864    }
865}