Skip to main content

compiler/
infer.rs

1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
/// Return information collected from one return site (a `return` statement or a tail
/// expression) while inferring a function body.
#[derive(Clone)]
struct ReturnInfo {
    /// Inferred return type; collapsed to `Type::Any` when `shape` is a map or list
    /// (see `infer_return_expr`).
    ty: Type,
    /// Structural shape of the returned value (struct, map, list or void), when one can be
    /// determined; used to check that all return sites agree.
    shape: Option<Type>,
}
11
12impl Compiler {
13    fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
14        self.infer_stack.last().cloned()
15    }
16
17    fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
18        self.fns.get(&id).and_then(|fns| {
19            fns.iter().find_map(|item| {
20                if item.0 == generic_args
21                    && item.1 == fn_tys
22                    && let FnInferRet::Pending(seed) = &item.2
23                {
24                    seed.clone()
25                } else {
26                    None
27                }
28            })
29        })
30    }
31
32    fn update_pending_return_seed(&mut self, ty: &Type) {
33        if ty.is_any() {
34            return;
35        }
36        let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
37            return;
38        };
39        let Some(fns) = self.fns.get_mut(&id) else {
40            return;
41        };
42        if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
43            && let FnInferRet::Pending(seed) = &mut item.2
44        {
45            let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
46            *seed = Some(next);
47        }
48    }
49
50    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
51        match &pat.kind {
52            PatternKind::Ident { name, ty } => {
53                let annotated_ty = self.symbols.get_type(ty)?;
54                self.add_name(name.clone());
55                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
56            }
57            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
58            PatternKind::Tuple(pats) => {
59                if let Type::Tuple(tys) = expr_ty {
60                    for (pat, ty) in pats.iter().zip(tys) {
61                        self.add_pattern_bindings_for_infer(pat, ty)?;
62                    }
63                } else {
64                    for pat in pats {
65                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
66                    }
67                }
68            }
69            PatternKind::List { elems, .. } => {
70                for pat in elems {
71                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
72                }
73            }
74            PatternKind::Wildcard => {
75                self.add_name("".into());
76                self.add_ty(expr_ty);
77            }
78            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
79        }
80        Ok(())
81    }
82
83    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
84        if matches!(range.kind, ExprKind::Range { .. }) {
85            return self.infer_range_expr(range);
86        }
87        Ok(match self.infer_expr(range)? {
88            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
89            _ => Type::Any,
90        })
91    }
92
93    fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
94        let ExprKind::Range { start, stop, .. } = &range.kind else {
95            return self.infer_expr(range);
96        };
97        let start_ty = self.infer_expr(start)?;
98        let stop_ty = self.infer_expr(stop)?;
99        Ok(Self::merge_range_bound_types(start_ty, stop_ty))
100    }
101
102    fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
103        if start_ty.is_any() {
104            stop_ty
105        } else if stop_ty.is_any() {
106            start_ty
107        } else if start_ty == Type::I32 && stop_ty.is_uint() {
108            stop_ty
109        } else if stop_ty == Type::I32 && start_ty.is_uint() {
110            start_ty
111        } else {
112            start_ty + stop_ty
113        }
114    }
115
116    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
117        match left {
118            Some(left) if left == right => Ok(left),
119            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
120            Some(left) => Ok(left + right),
121            None => Ok(right),
122        }
123    }
124
125    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
126        if !ty.is_any() {
127            return match ty {
128                Type::Struct { .. } => Some(ty.clone()),
129                Type::Map => Some(Type::Map),
130                Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
131                _ => None,
132            };
133        }
134        match &expr.kind {
135            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
136            ExprKind::Dict(_) => Some(Type::Map),
137            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
138            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
139            ExprKind::Typed { ty, .. } => Some(ty.clone()),
140            _ => None,
141        }
142    }
143
144    fn dynamic_return_shape(ty: Type) -> Option<Type> {
145        match ty {
146            Type::Map => Some(Type::Map),
147            Type::List(elem) => Some(Type::List(elem)),
148            Type::Array(elem, _) => Some(Type::List(elem)),
149            _ => None,
150        }
151    }
152
153    fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
154        match &expr.kind {
155            ExprKind::Var(idx) => Some(*idx),
156            ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
157            _ => None,
158        }
159    }
160
161    fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
162        match method {
163            "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
164                Some(ListElemState::Known(ty)) => ty,
165                Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
166                None => elem_ty.clone(),
167            })),
168            "push" => {
169                let pushed_ty = params
170                    .first()
171                    .map(|param| {
172                        if let Some(value) = self.get_value(param)
173                            && (value.is_str() || value.is_native())
174                        {
175                            Ok(value.get_type())
176                        } else {
177                            self.infer_expr(param)
178                        }
179                    })
180                    .transpose()?
181                    .unwrap_or(Type::Any);
182                if let Some(idx) = self.local_var_idx_for_expr(target) {
183                    let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
184                    let next_state = match state {
185                        ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
186                        ListElemState::Unknown => ListElemState::Known(pushed_ty),
187                        ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
188                        ListElemState::Known(prev) => {
189                            let merged = if prev == pushed_ty {
190                                prev
191                            } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
192                                prev + pushed_ty
193                            } else {
194                                Type::Any
195                            };
196                            if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
197                        }
198                        ListElemState::Mixed => ListElemState::Mixed,
199                    };
200                    let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
201                    self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
202                    self.set_list_elem_state(idx, Some(next_state));
203                }
204                Ok(Some(Type::Void))
205            }
206            "len" => Ok(Some(Type::I32)),
207            "is_list" | "is_null" => Ok(Some(Type::Bool)),
208            _ => Ok(None),
209        }
210    }
211
212    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
213        let ty = self.infer_expr(expr)?;
214        let shape = self.return_shape(expr, &ty);
215        let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
216        Ok(ReturnInfo { ty, shape })
217    }
218
219    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
220        let Some(left) = left else {
221            return Ok(right);
222        };
223        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
224            && left_shape != right_shape
225        {
226            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
227        }
228        if let Some(left_shape) = &left.shape
229            && left_shape.is_struct()
230            && right.ty.is_any()
231            && right.shape.is_none()
232        {
233            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
234        }
235        if let Some(right_shape) = &right.shape
236            && right_shape.is_struct()
237            && left.ty.is_any()
238            && left.shape.is_none()
239        {
240            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
241        }
242        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
243        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
244    }
245
246    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
247        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
248    }
249
250    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
251        self.infer_returns(stmt, true).map(|_| ())
252    }
253
    /// Walks `stmt` and collects the return information of every `return` statement. When
    /// `tail` is true, the trailing unterminated expression statement is collected too.
    ///
    /// Returns the merged `ReturnInfo` (or `None` if no return site was found). Also returns a
    /// flag that is true when the statement returns on every path this walk models.
    ///
    /// While walking, each discovered return type (and each merged result) is fed into
    /// `update_pending_return_seed`. NOTE(review): presumably this lets recursive calls see a
    /// partial return type; confirm against `infer_fn`.
    ///
    /// # Errors
    /// Fails when:
    /// * an `if`/`while` condition does not infer to `Type::Bool`;
    /// * `merge_return_info` finds inconsistent return sites;
    /// * inferring a nested expression or statement fails.
    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
        match &stmt.kind {
            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
            StmtKind::Block(stmts) => {
                let mut ret = None;
                for (idx, stmt) in stmts.iter().enumerate() {
                    // Only the last statement of the block inherits tail position.
                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
                    if let Some(info) = info {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    // Statements after one that always returns are not inspected.
                    if always_returns {
                        return Ok((ret, true));
                    }
                }
                Ok((ret, false))
            }
            StmtKind::If { cond, then_body, else_body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
                if let Some(ret) = &ret {
                    self.update_pending_return_seed(&ret.ty);
                }
                // Without an `else`, the `if` can fall through, so it never always-returns.
                let else_returns = if let Some(body) = else_body {
                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
                    if let Some(info) = else_ty {
                        self.update_pending_return_seed(&info.ty);
                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
                        if let Some(ret) = &ret {
                            self.update_pending_return_seed(&ret.ty);
                        }
                    }
                    else_returns
                } else {
                    false
                };
                Ok((ret, then_returns && else_returns))
            }
            StmtKind::While { cond, body } => {
                let cond_ty = self.infer_expr(cond)?;
                if cond_ty != Type::Bool {
                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
                }
                // Loop bodies are never in tail position, and a `while` is never reported as
                // always returning.
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            // `loop` passes its body's always-returns flag through unchanged.
            StmtKind::Loop(body) => self.infer_returns(body, false),
            StmtKind::For { pat, range, body } => {
                // Register the loop pattern's bindings before walking the body.
                let ty = self.for_pattern_ty(range)?;
                self.add_pattern_bindings_for_infer(pat, ty)?;
                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
            }
            StmtKind::Let { .. } => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
            StmtKind::Expr(expr, close) => {
                // The expression is always inferred. It only counts as a return value when it is
                // in tail position and `close` is false (presumably: no trailing `;` — confirm).
                let info = self.infer_return_expr(expr)?;
                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
            }
            _ => {
                self.infer_stmt(stmt)?;
                Ok((None, false))
            }
        }
    }
326
327    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
328        match &expr.kind {
329            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
330            ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
331            ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
332            ExprKind::Value(v) => Ok(v.get_type()),
333            ExprKind::Const(idx) => Ok(if self.consts.get(*idx).is_some_and(|value| value.is_list() && value.len() == 0) { Type::list_any() } else { Type::Any }),
334            ExprKind::Var(idx) => {
335                let idx = self.top() + (*idx as usize);
336                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
337            }
338            ExprKind::Ident(ident) => {
339                for idx in (self.top()..self.names.len()).rev() {
340                    if self.names[idx].eq(ident) && idx < self.tys.len() {
341                        return self.symbols.get_type(&self.tys[idx]);
342                    }
343                }
344                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
345                match self.symbols.get_symbol(id)?.1 {
346                    Symbol::Const { ty, .. } => Ok(ty.clone()),
347                    Symbol::Static { ty, .. } => Ok(ty.clone()),
348                    Symbol::Struct(ty, _) => Ok(ty.clone()),
349                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
350                    Symbol::Native(ty) => Ok(ty.clone()),
351                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
352                }
353            }
354            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
355                Symbol::Const { ty, .. } => Ok(ty.clone()),
356                Symbol::Static { ty, .. } => Ok(ty.clone()),
357                Symbol::Struct(ty, _) => Ok(ty.clone()),
358                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
359                Symbol::Native(ty) => Ok(ty.clone()),
360                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
361            },
362            ExprKind::Generic { obj, params } => {
363                let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
364                match self.infer_expr(obj)? {
365                    Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
366                    _ => Ok(Type::Any),
367                }
368            }
369            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
370            ExprKind::Unary { op, value } => match op {
371                UnaryOp::Not => {
372                    let ty = self.infer_expr(value.as_ref())?;
373                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
374                }
375                UnaryOp::Neg => self.infer_expr(value.as_ref()),
376                UnaryOp::Unknow => Ok(Type::Any),
377            },
378            ExprKind::Binary { left, op, right } => {
379                if op == &BinaryOp::Assign
380                    && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
381                {
382                    if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
383                        if left_items.len() != right_items.len() {
384                            return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
385                        }
386                        for item in right_items {
387                            let _ = self.infer_expr(item)?;
388                        }
389                    } else {
390                        let _ = self.infer_expr(right)?;
391                    }
392                    return Ok(Type::Void);
393                }
394                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
395                let ty = if op.is_logic() {
396                    Type::Bool
397                } else if op == &BinaryOp::Idx {
398                    let left_ty = self.infer_expr(left)?;
399                    if let Type::Array(elem_ty, _) = left_ty {
400                        (*elem_ty).clone()
401                    } else if let Type::Vec(elem_ty, _) = left_ty {
402                        (*elem_ty).clone()
403                    } else if let Type::List(elem_ty) = left_ty {
404                        (*elem_ty).clone()
405                    } else {
406                        let left_ty = self.symbols.get_type(&left_ty)?;
407                        let right_ty = if right.is_value() || right.is_const() {
408                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
409                            if right_value.is_str() {
410                                if left_ty.is_any() {
411                                    return Ok(Type::Any);
412                                }
413                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
414                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
415                                }
416                            } else if let Type::Struct { fields, .. } = &left_ty
417                                && let Some(idx) = right_value.as_int()
418                            {
419                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
420                            }
421                            right_value.get_type()
422                        } else {
423                            self.infer_expr(right)?
424                        };
425                        if right_ty.is_int() || right_ty.is_uint() {
426                            if left_ty.is_any() {
427                                return Ok(Type::Any);
428                            }
429                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
430                            let fn_ty = self.symbols.get_type(&s)?;
431                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
432                        }
433                        if left_ty.is_any() {
434                            return Ok(Type::Any);
435                        }
436                        Type::Any
437                    }
438                } else {
439                    let left_ty = self.infer_expr(left)?;
440                    let right_ty = self.infer_expr(right)?;
441                    if op == &BinaryOp::Assign {
442                        if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
443                    } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
444                        left_ty
445                    } else {
446                        left_ty + right_ty
447                    }
448                };
449                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
450                Ok(ty)
451            }
452            ExprKind::Call { obj, params } => {
453                if let ExprKind::Assoc { ty, name } = &obj.kind {
454                    let base_name = match ty {
455                        Type::Ident { name, .. } => name.clone(),
456                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
457                        _ => return Ok(Type::Any),
458                    };
459                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
460                    let generic_args = match ty {
461                        Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
462                        _ => Vec::new(),
463                    };
464                    let mut args = Vec::new();
465                    for p in params {
466                        args.push(self.infer_expr(p)?);
467                    }
468                    self.infer_fn_with_params(id, &args, &generic_args)
469                } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
470                    let mut args = Vec::new();
471                    for p in params {
472                        args.push(self.infer_expr(p)?);
473                    }
474                    self.infer_fn_with_params(*id, &args, generic_args)
475                } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
476                    let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
477                        return Ok(Type::Any);
478                    };
479                    let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
480                    let mut args = Vec::new();
481                    for p in params {
482                        args.push(self.infer_expr(p)?);
483                    }
484                    self.infer_fn_with_params(id, &args, &generic_args)
485                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
486                    let base_name = match ty {
487                        Type::Ident { name, .. } => name.clone(),
488                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
489                        _ => return Ok(Type::Any),
490                    };
491                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
492                    let mut args = vec![self.infer_expr(target)?];
493                    for p in params {
494                        args.push(self.infer_expr(p)?);
495                    }
496                    self.infer_fn(id, &args)
497                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
498                    let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
499                    if let Some(target) = obj_expr
500                        && let Some(method) = method
501                    {
502                        let target_ty = self.infer_expr(target)?;
503                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
504                            && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
505                        {
506                            return Ok(ret_ty);
507                        }
508                    }
509                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
510                    for p in params {
511                        args.push(self.infer_expr(p)?);
512                    }
513                    self.infer_fn(*id, &args)
514                } else if let ExprKind::Ident(name) = &obj.kind {
515                    for idx in (self.top()..self.names.len()).rev() {
516                        if self.names[idx].eq(name) && idx < self.tys.len() {
517                            return if let Type::Symbol { id, .. } = &self.tys[idx] {
518                                let id = *id;
519                                let mut args = Vec::new();
520                                for p in params {
521                                    args.push(self.infer_expr(p)?);
522                                }
523                                self.infer_fn(id, &args)
524                            } else {
525                                Ok(Type::Any)
526                            };
527                        }
528                    }
529                    let Ok(id) = self.symbols.get_id(name) else {
530                        return Ok(Type::Any);
531                    };
532                    if !self.symbols.get_symbol(id)?.1.is_fn() {
533                        return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
534                    }
535                    let mut args = Vec::new();
536                    for p in params {
537                        args.push(self.infer_expr(p)?);
538                    }
539                    self.infer_fn(id, &args)
540                } else if obj.is_idx() {
541                    let (target, _, method) = obj.clone().binary().unwrap();
542                    let ty = self.infer_expr(&target)?;
543                    if let Some(method) = self.get_value(&method) {
544                        let method = method.as_str();
545                        if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
546                            && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
547                        {
548                            return Ok(ret_ty);
549                        }
550                        let fn_ty = match self.get_field(&ty, method) {
551                            Ok((_, fn_ty)) => fn_ty,
552                            Err(_) => {
553                                let id = self.symbols.get_id(method)?;
554                                if self.symbols.get_symbol(id)?.1.is_fn() {
555                                    Type::Symbol { id, params: Vec::new() }
556                                } else {
557                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
558                                }
559                            }
560                        };
561                        if let Type::Symbol { id, .. } = fn_ty {
562                            let mut args = vec![ty];
563                            for p in params {
564                                args.push(self.infer_expr(p)?);
565                            }
566                            self.infer_fn(id, &args)
567                        } else {
568                            Ok(fn_ty)
569                        }
570                    } else {
571                        Ok(Type::Any)
572                    }
573                } else if let ExprKind::Var(idx) = &obj.kind {
574                    let idx = self.top() + (*idx as usize);
575                    if idx < self.tys.len()
576                        && let Type::Symbol { id, .. } = self.tys[idx]
577                    {
578                        let mut args = Vec::new();
579                        for p in params {
580                            args.push(self.infer_expr(p)?);
581                        }
582                        self.infer_fn(id, &args)
583                    } else {
584                        Ok(Type::Any)
585                    }
586                } else if obj.is_value() {
587                    Ok(Type::Void)
588                } else {
589                    Ok(Type::Any)
590                }
591            }
592            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
593            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
594            ExprKind::Repeat { value, len } => {
595                let value_ty = self.infer_expr(value)?;
596                let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
597                if let Type::ConstInt(len) = len {
598                    let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
599                    Ok(Type::Array(std::rc::Rc::new(value_ty), len))
600                } else {
601                    Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
602                }
603            }
604            ExprKind::List(items) => {
605                if items.is_empty() {
606                    return Ok(Type::list_any());
607                }
608                let mut elem_ty = Type::Any;
609                for item in items {
610                    let item_ty = self.infer_expr(item)?;
611                    elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
612                }
613                Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
614            }
615            ExprKind::Range { start, stop, .. } => {
616                let start_ty = self.infer_expr(start)?;
617                let stop_ty = self.infer_expr(stop)?;
618                Ok(Self::merge_range_bound_types(start_ty, stop_ty))
619            }
620            _ => Ok(Type::Any),
621        }
622    }
623
624    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
625        let mut fn_tys = Vec::new();
626        for (i, ty) in tys.iter().enumerate() {
627            if !ty.is_any() {
628                fn_tys.push(ty.clone());
629            } else if let Some(arg_ty) = arg_tys.get(i) {
630                fn_tys.push(self.symbols.get_type(arg_ty)?);
631            } else {
632                fn_tys.push(Type::Any);
633            }
634        }
635        Ok(fn_tys)
636    }
637
638    fn is_optimizable_local_ty(ty: &Type) -> bool {
639        ty.is_bool() || ty.is_native()
640    }
641
642    fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
643        matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
644    }
645
646    fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
647        let ty = self.tys.get(pos)?;
648        match ty {
649            Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
650                if let ListElemState::Known(elem_ty) = state
651                    && Self::is_optimizable_list_elem_ty(&elem_ty)
652                {
653                    Some(Type::List(std::rc::Rc::new(elem_ty)))
654                } else {
655                    None
656                }
657            }),
658            ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
659            _ => None,
660        }
661    }
662
663    fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
664        (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
665    }
666
667    fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
668        let items = self.local_type_hints.entry(id).or_default();
669        if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
670            item.2 = hints;
671        } else {
672            items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
673        }
674    }
675
676    pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
677        self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
678    }
679
680    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
681        self.infer_fn_with_params(id, arg_tys, &[])
682    }
683
684    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
685        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
686        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
687            if let Type::Fn { tys, ret: _ } = ty {
688                let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
689                let generic_args = resolved_generic_args.as_slice();
690                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
691                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
692                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
693                let body = if generic_params.is_empty() {
694                    body
695                } else {
696                    let mut compile_tys = tys.clone();
697                    let mut compile_cap = cap.clone();
698                    let saved_state = self.take_local_state();
699                    if let Some((module, _)) = name.split_once("::") {
700                        self.symbols.push_module_scope(module.into());
701                    }
702                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
703                    if name.contains("::") {
704                        self.symbols.pop_module_scope();
705                    }
706                    self.restore_local_state(saved_state);
707                    Stmt::new(StmtKind::Block(compiled?), Span::default())
708                };
709                if let Some(fns) = self.fns.get_mut(&id) {
710                    for f in fns.iter() {
711                        if f.0 == generic_args && f.1 == fn_tys {
712                            return match &f.2 {
713                                FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
714                                FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or(Ok(Type::Any)),
715                            };
716                        }
717                    }
718                    fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
719                } else {
720                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
721                }
722                let mut ret_ty = None;
723                let mut local_type_hints = Vec::new();
724                for _ in 0..4 {
725                    let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
726                    let saved_state = self.take_local_state();
727                    self.frames.push(0);
728                    for (arg, ty) in args.iter().zip(fn_tys.iter()) {
729                        self.add_name(arg.clone());
730                        self.add_ty(ty.clone());
731                    }
732                    for c in cap.vars.iter() {
733                        if let Some((name, ty)) = cap.names.get(*c) {
734                            self.add_name(name.clone());
735                            self.add_ty(ty.clone());
736                        } else {
737                            self.add_name("".into());
738                            self.add_ty(Type::Any);
739                        }
740                    }
741                    self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
742                    let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
743                    self.infer_stack.pop();
744                    let pass_local_type_hints = self.collect_local_type_hints();
745                    self.restore_local_state(saved_state);
746                    let pass_ret_ty = match pass_ret_ty {
747                        Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
748                        Err(err) => {
749                            log::error!("infer_fn {} failed: {:?}", name, err);
750                            let should_remove = self
751                                .fns
752                                .get_mut(&id)
753                                .map(|fns| {
754                                    fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
755                                    fns.is_empty()
756                                })
757                                .unwrap_or(false);
758                            if should_remove {
759                                self.fns.remove(&id);
760                            }
761                            return Err(err);
762                        }
763                    };
764                    if !pass_ret_ty.is_any() {
765                        self.update_pending_return_seed(&pass_ret_ty);
766                        ret_ty = Some(pass_ret_ty.clone());
767                    } else if ret_ty.is_none() {
768                        ret_ty = Some(pass_ret_ty);
769                    }
770                    local_type_hints = pass_local_type_hints;
771                    let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
772                    if before_seed == after_seed {
773                        break;
774                    }
775                }
776                let ret_ty = ret_ty.unwrap_or(Type::Any);
777                self.fns.get_mut(&id).map(|f| {
778                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
779                });
780                self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
781                if generic_args.is_empty()
782                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
783                    && ret.is_any()
784                {
785                    *ret = std::rc::Rc::new(ret_ty.clone());
786                }
787                Ok(ret_ty)
788            } else {
789                Ok(Type::Any)
790            }
791        } else if let Symbol::Native(f) = s {
792            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
793        } else if matches!(s, Symbol::Null) {
794            Ok(Type::Any)
795        } else {
796            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
797        }
798    }
799
800    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
801        match &stmt.kind {
802            StmtKind::Expr(expr, close) => {
803                if !close {
804                    self.infer_expr(expr)
805                } else {
806                    self.infer_expr(expr)?;
807                    Ok(Type::Void)
808                }
809            }
810            StmtKind::Return(expr) => {
811                if let Some(e) = expr {
812                    self.infer_expr(e)
813                } else {
814                    Ok(Type::Void)
815                }
816            }
817            StmtKind::Block(stmts) => {
818                for (idx, stmt) in stmts.iter().enumerate() {
819                    let ty = self.infer_stmt(stmt)?;
820                    if stmt.is_return() || idx == stmts.len() - 1 {
821                        return Ok(ty);
822                    }
823                }
824                Ok(Type::Void)
825            }
826            StmtKind::If { then_body, else_body, .. } => {
827                let then_ty = self.infer_stmt(then_body)?;
828                if let Some(e) = else_body {
829                    let else_ty = self.infer_stmt(e)?;
830                    if then_ty != else_ty {
831                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
832                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
833                    }
834                }
835                if else_body.is_none() {
836                    return Ok(Type::Void);
837                }
838                Ok(then_ty)
839            }
840            StmtKind::While { cond, body } => {
841                let cond_ty = self.infer_expr(cond)?;
842                if cond_ty != Type::Bool {
843                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
844                }
845                self.infer_stmt(body)
846            }
847            StmtKind::For { pat, range, body } => {
848                let ty = self.for_pattern_ty(range)?;
849                self.add_pattern_bindings_for_infer(pat, ty)?;
850                self.infer_stmt(body)
851            }
852            StmtKind::Let { pat, value } => {
853                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
854                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
855                Ok(Type::Void)
856            }
857            _ => Ok(Type::Void),
858        }
859    }
860}