Skip to main content

compiler/
infer.rs

1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8    ty: Type,
9    shape: Option<Type>,
10}
11
12impl Compiler {
13    fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
14        match &pat.kind {
15            PatternKind::Ident { name, ty } => {
16                let annotated_ty = self.symbols.get_type(ty)?;
17                self.add_name(name.clone());
18                self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
19            }
20            PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
21            PatternKind::Tuple(pats) => {
22                if let Type::Tuple(tys) = expr_ty {
23                    for (pat, ty) in pats.iter().zip(tys) {
24                        self.add_pattern_bindings_for_infer(pat, ty)?;
25                    }
26                } else {
27                    for pat in pats {
28                        self.add_pattern_bindings_for_infer(pat, Type::Any)?;
29                    }
30                }
31            }
32            PatternKind::List { elems, .. } => {
33                for pat in elems {
34                    self.add_pattern_bindings_for_infer(pat, Type::Any)?;
35                }
36            }
37            PatternKind::Wildcard => {
38                self.add_name("".into());
39                self.add_ty(expr_ty);
40            }
41            PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
42        }
43        Ok(())
44    }
45
46    fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
47        if matches!(range.kind, ExprKind::Range { .. }) {
48            return self.infer_expr(range);
49        }
50        Ok(match self.infer_expr(range)? {
51            Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) => elem_ty.as_ref().clone(),
52            _ => Type::Any,
53        })
54    }
55
56    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
57        match left {
58            Some(left) if left == right => Ok(left),
59            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
60            Some(left) => Ok(left + right),
61            None => Ok(right),
62        }
63    }
64
65    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
66        if !ty.is_any() {
67            return if ty.is_struct() { Some(ty.clone()) } else { None };
68        }
69        match &expr.kind {
70            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::List),
71            ExprKind::Dict(_) => Some(Type::Map),
72            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
73            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
74            ExprKind::Typed { ty, .. } => Some(ty.clone()),
75            _ => None,
76        }
77    }
78
79    fn dynamic_return_shape(ty: Type) -> Option<Type> {
80        match ty {
81            Type::Map => Some(Type::Map),
82            Type::List | Type::Array(_, _) => Some(Type::List),
83            _ => None,
84        }
85    }
86
87    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
88        let ty = self.infer_expr(expr)?;
89        let shape = self.return_shape(expr, &ty);
90        Ok(ReturnInfo { ty, shape })
91    }
92
93    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
94        let Some(left) = left else {
95            return Ok(right);
96        };
97        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
98            && left_shape != right_shape
99        {
100            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
101        }
102        if let Some(left_shape) = &left.shape
103            && left_shape.is_struct()
104            && right.ty.is_any()
105            && right.shape.is_none()
106        {
107            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
108        }
109        if let Some(right_shape) = &right.shape
110            && right_shape.is_struct()
111            && left.ty.is_any()
112            && left.shape.is_none()
113        {
114            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
115        }
116        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
117        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
118    }
119
120    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
121        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
122    }
123
124    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
125        self.infer_returns(stmt, true).map(|_| ())
126    }
127
128    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
129        match &stmt.kind {
130            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
131            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
132            StmtKind::Block(stmts) => {
133                let mut ret = None;
134                for (idx, stmt) in stmts.iter().enumerate() {
135                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
136                    if let Some(info) = info {
137                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
138                    }
139                    if always_returns {
140                        return Ok((ret, true));
141                    }
142                }
143                Ok((ret, false))
144            }
145            StmtKind::If { cond, then_body, else_body } => {
146                let cond_ty = self.infer_expr(cond)?;
147                if cond_ty != Type::Bool {
148                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
149                }
150                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
151                let else_returns = if let Some(body) = else_body {
152                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
153                    if let Some(info) = else_ty {
154                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
155                    }
156                    else_returns
157                } else {
158                    false
159                };
160                Ok((ret, then_returns && else_returns))
161            }
162            StmtKind::While { cond, body } => {
163                let cond_ty = self.infer_expr(cond)?;
164                if cond_ty != Type::Bool {
165                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
166                }
167                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
168            }
169            StmtKind::Loop(body) => self.infer_returns(body, false),
170            StmtKind::For { pat, range, body } => {
171                let ty = self.for_pattern_ty(range)?;
172                self.add_pattern_bindings_for_infer(pat, ty)?;
173                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
174            }
175            StmtKind::Let { .. } => {
176                self.infer_stmt(stmt)?;
177                Ok((None, false))
178            }
179            StmtKind::Expr(expr, close) => {
180                let info = self.infer_return_expr(expr)?;
181                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
182            }
183            _ => {
184                self.infer_stmt(stmt)?;
185                Ok((None, false))
186            }
187        }
188    }
189
190    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
191        match &expr.kind {
192            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
193            ExprKind::Value(v) if v.is_list() || v.is_map() => Ok(Type::Any),
194            ExprKind::Value(v) => Ok(v.get_type()),
195            ExprKind::Const(_) => Ok(Type::Any),
196            ExprKind::Var(idx) => {
197                let idx = self.top() + (*idx as usize);
198                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
199            }
200            ExprKind::Ident(ident) => {
201                for idx in (self.top()..self.names.len()).rev() {
202                    if self.names[idx].eq(ident) && idx < self.tys.len() {
203                        return self.symbols.get_type(&self.tys[idx]);
204                    }
205                }
206                let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
207                match self.symbols.get_symbol(id)?.1 {
208                    Symbol::Const { ty, .. } => Ok(ty.clone()),
209                    Symbol::Static { ty, .. } => Ok(ty.clone()),
210                    Symbol::Struct(ty, _) => Ok(ty.clone()),
211                    Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
212                    Symbol::Native(ty) => Ok(ty.clone()),
213                    s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
214                }
215            }
216            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
217                Symbol::Const { ty, .. } => Ok(ty.clone()),
218                Symbol::Static { ty, .. } => Ok(ty.clone()),
219                Symbol::Struct(ty, _) => Ok(ty.clone()),
220                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
221                Symbol::Native(ty) => Ok(ty.clone()),
222                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
223            },
224            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
225            ExprKind::Unary { op, value } => match op {
226                UnaryOp::Not => {
227                    let ty = self.infer_expr(value.as_ref())?;
228                    if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
229                }
230                UnaryOp::Neg => self.infer_expr(value.as_ref()),
231                UnaryOp::Unknow => Ok(Type::Any),
232            },
233            ExprKind::Binary { left, op, right } => {
234                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
235                let ty = if op.is_logic() { Type::Bool } else if op == &BinaryOp::Idx {
236                    let left_ty = self.infer_expr(left)?;
237                    if let Type::Array(elem_ty, _) = left_ty {
238                        (*elem_ty).clone()
239                    } else if let Type::Vec(elem_ty, _) = left_ty {
240                        (*elem_ty).clone()
241                    } else {
242                        let left_ty = self.symbols.get_type(&left_ty)?;
243                        let right_ty = if right.is_value() || right.is_const() {
244                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
245                            if right_value.is_str() {
246                                if left_ty.is_any() {
247                                    return Ok(Type::Any);
248                                }
249                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
250                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
251                                }
252                            } else if let Type::Struct { fields, .. } = &left_ty
253                                && let Some(idx) = right_value.as_int()
254                            {
255                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
256                            }
257                            right_value.get_type()
258                        } else {
259                            self.infer_expr(right)?
260                        };
261                        if right_ty.is_int() || right_ty.is_uint() {
262                            if left_ty.is_any() {
263                                return Ok(Type::Any);
264                            }
265                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
266                            let fn_ty = self.symbols.get_type(&s)?;
267                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
268                        }
269                        if left_ty.is_any() {
270                            return Ok(Type::Any);
271                        }
272                        Type::Any
273                    }
274                } else {
275                    let right_ty = self.infer_expr(right)?;
276                    if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
277                };
278                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
279                Ok(ty)
280            }
281            ExprKind::Call { obj, params } => {
282                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
283                    let mut args = Vec::new();
284                    for p in params {
285                        args.push(self.infer_expr(p)?);
286                    }
287                    self.infer_fn_with_params(*id, &args, generic_args)
288                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
289                    let base_name = match ty {
290                        Type::Ident { name, .. } => name.clone(),
291                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
292                        _ => return Ok(Type::Any),
293                    };
294                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
295                    let mut args = vec![self.infer_expr(target)?];
296                    for p in params {
297                        args.push(self.infer_expr(p)?);
298                    }
299                    self.infer_fn(id, &args)
300                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
301                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
302                    for p in params {
303                        args.push(self.infer_expr(p)?);
304                    }
305                    self.infer_fn(*id, &args)
306                } else if obj.is_idx() {
307                    let (target, _, method) = obj.clone().binary().unwrap();
308                    let ty = self.infer_expr(&target)?;
309                    if let Some(method) = self.get_value(&method) {
310                        let method = method.as_str();
311                        let fn_ty = match self.get_field(&ty, method) {
312                            Ok((_, fn_ty)) => fn_ty,
313                            Err(_) => {
314                                let id = self.symbols.get_id(method)?;
315                                if self.symbols.get_symbol(id)?.1.is_fn() {
316                                    Type::Symbol { id, params: Vec::new() }
317                                } else {
318                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
319                                }
320                            }
321                        };
322                        if let Type::Symbol { id, .. } = fn_ty {
323                            let mut args = vec![ty];
324                            for p in params {
325                                args.push(self.infer_expr(p)?);
326                            }
327                            self.infer_fn(id, &args)
328                        } else {
329                            Ok(fn_ty)
330                        }
331                    } else {
332                        Ok(Type::Any)
333                    }
334                } else if let ExprKind::Var(idx) = &obj.kind {
335                    let idx = self.top() + (*idx as usize);
336                    if idx < self.tys.len()
337                        && let Type::Symbol { id, .. } = self.tys[idx]
338                    {
339                        let mut args = Vec::new();
340                        for p in params {
341                            args.push(self.infer_expr(p)?);
342                        }
343                        self.infer_fn(id, &args)
344                    } else {
345                        Ok(Type::Any)
346                    }
347                } else if obj.is_value() {
348                    Ok(Type::Void)
349                } else {
350                    Ok(Type::Any)
351                }
352            }
353            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
354            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
355            ExprKind::Range { start, stop, .. } => {
356                let start_ty = self.infer_expr(start)?;
357                let stop_ty = self.infer_expr(stop)?;
358                Ok(if start_ty.is_any() {
359                    stop_ty
360                } else if stop_ty.is_any() {
361                    start_ty
362                } else {
363                    stop_ty
364                })
365            }
366            _ => Ok(Type::Any),
367        }
368    }
369
370    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
371        let mut fn_tys = Vec::new();
372        for (i, ty) in tys.iter().enumerate() {
373            if !ty.is_any() {
374                fn_tys.push(ty.clone());
375            } else if let Some(arg_ty) = arg_tys.get(i) {
376                fn_tys.push(self.symbols.get_type(arg_ty)?);
377            } else {
378                fn_tys.push(Type::Any);
379            }
380        }
381        Ok(fn_tys)
382    }
383
384    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
385        self.infer_fn_with_params(id, arg_tys, &[])
386    }
387
388    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
389        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
390        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
391            if let Type::Fn { tys, ret: _ } = ty {
392                let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
393                let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
394                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
395                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
396                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
397                let body = if generic_params.is_empty() {
398                    body
399                } else {
400                    let mut compile_tys = tys.clone();
401                    let mut compile_cap = cap.clone();
402                    let saved_state = self.take_local_state();
403                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
404                    self.restore_local_state(saved_state);
405                    Stmt::new(StmtKind::Block(compiled?), Span::default())
406                };
407                if let Some(fns) = self.fns.get_mut(&id) {
408                    for f in fns.iter() {
409                        if f.0 == generic_args && f.1 == fn_tys {
410                            return self.symbols.get_type(&f.2);
411                        }
412                    }
413                    fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
414                } else {
415                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
416                }
417                let top = self.tys.len();
418                self.tys.append(&mut fn_tys.clone());
419                for c in cap.vars.iter() {
420                    self.tys.push(self.tys[self.top() + *c].clone());
421                }
422                self.frames.push(top);
423                let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
424                if let Some(top) = self.frames.pop() {
425                    self.tys.truncate(top);
426                }
427                let ret_ty = match ret_ty {
428                    Ok(ret_ty) => self.symbols.get_type(&ret_ty).unwrap_or(ret_ty),
429                    Err(err) => {
430                        log::error!("infer_fn {} failed: {:?}", name, err);
431                        let should_remove = self
432                            .fns
433                            .get_mut(&id)
434                            .map(|fns| {
435                                fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
436                                fns.is_empty()
437                            })
438                            .unwrap_or(false);
439                        if should_remove {
440                            self.fns.remove(&id);
441                        }
442                        return Err(err);
443                    }
444                };
445                self.fns.get_mut(&id).map(|f| {
446                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
447                });
448                if generic_args.is_empty()
449                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
450                    && ret.is_any()
451                {
452                    *ret = std::rc::Rc::new(ret_ty.clone());
453                }
454                Ok(ret_ty)
455            } else {
456                Ok(Type::Any)
457            }
458        } else if let Symbol::Native(f) = s {
459            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
460        } else if matches!(s, Symbol::Null) {
461            Ok(Type::Any)
462        } else {
463            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
464        }
465    }
466
467    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
468        match &stmt.kind {
469            StmtKind::Expr(expr, close) => {
470                if !close {
471                    self.infer_expr(expr)
472                } else {
473                    self.infer_expr(expr)?;
474                    Ok(Type::Void)
475                }
476            }
477            StmtKind::Return(expr) => {
478                if let Some(e) = expr {
479                    self.infer_expr(e)
480                } else {
481                    Ok(Type::Void)
482                }
483            }
484            StmtKind::Block(stmts) => {
485                for (idx, stmt) in stmts.iter().enumerate() {
486                    let ty = self.infer_stmt(stmt)?;
487                    if stmt.is_return() || idx == stmts.len() - 1 {
488                        return Ok(ty);
489                    }
490                }
491                Ok(Type::Void)
492            }
493            StmtKind::If { then_body, else_body, .. } => {
494                let then_ty = self.infer_stmt(then_body)?;
495                if let Some(e) = else_body {
496                    let else_ty = self.infer_stmt(e)?;
497                    if then_ty != else_ty {
498                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
499                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
500                    }
501                }
502                if else_body.is_none() {
503                    return Ok(Type::Void);
504                }
505                Ok(then_ty)
506            }
507            StmtKind::While { cond, body } => {
508                let cond_ty = self.infer_expr(cond)?;
509                if cond_ty != Type::Bool {
510                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
511                }
512                self.infer_stmt(body)
513            }
514            StmtKind::For { pat, range, body } => {
515                let ty = self.for_pattern_ty(range)?;
516                self.add_pattern_bindings_for_infer(pat, ty)?;
517                self.infer_stmt(body)
518            }
519            StmtKind::Let { pat, value } => {
520                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
521                self.add_pattern_bindings_for_infer(pat, expr_ty)?;
522                Ok(Type::Void)
523            }
524            _ => Ok(Type::Void),
525        }
526    }
527}