Skip to main content

compiler/
infer.rs

1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8    ty: Type,
9    shape: Option<Type>,
10}
11
12impl Compiler {
13    fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
14        match left {
15            Some(left) if left == right => Ok(left),
16            Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
17            Some(left) => Ok(left + right),
18            None => Ok(right),
19        }
20    }
21
22    fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
23        if !ty.is_any() {
24            return if ty.is_struct() { Some(ty.clone()) } else { None };
25        }
26        match &expr.kind {
27            ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::List),
28            ExprKind::Dict(_) => Some(Type::Map),
29            ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
30            ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
31            ExprKind::Typed { ty, .. } => Some(ty.clone()),
32            _ => None,
33        }
34    }
35
36    fn dynamic_return_shape(ty: Type) -> Option<Type> {
37        match ty {
38            Type::Map => Some(Type::Map),
39            Type::List | Type::Array(_, _) => Some(Type::List),
40            _ => None,
41        }
42    }
43
44    fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
45        let ty = self.infer_expr(expr)?;
46        let shape = self.return_shape(expr, &ty);
47        Ok(ReturnInfo { ty, shape })
48    }
49
50    fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
51        let Some(left) = left else {
52            return Ok(right);
53        };
54        if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
55            && left_shape != right_shape
56        {
57            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
58        }
59        if let Some(left_shape) = &left.shape
60            && left_shape.is_struct()
61            && right.ty.is_any()
62            && right.shape.is_none()
63        {
64            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
65        }
66        if let Some(right_shape) = &right.shape
67            && right_shape.is_struct()
68            && left.ty.is_any()
69            && left.shape.is_none()
70        {
71            return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
72        }
73        let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
74        Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
75    }
76
77    fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
78        self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
79    }
80
81    pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
82        self.infer_returns(stmt, true).map(|_| ())
83    }
84
85    fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
86        match &stmt.kind {
87            StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
88            StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
89            StmtKind::Block(stmts) => {
90                let mut ret = None;
91                for (idx, stmt) in stmts.iter().enumerate() {
92                    let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
93                    if let Some(info) = info {
94                        ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
95                    }
96                    if always_returns {
97                        return Ok((ret, true));
98                    }
99                }
100                Ok((ret, false))
101            }
102            StmtKind::If { cond, then_body, else_body } => {
103                let cond_ty = self.infer_expr(cond)?;
104                if cond_ty != Type::Bool {
105                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
106                }
107                let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
108                let else_returns = if let Some(body) = else_body {
109                    let (else_ty, else_returns) = self.infer_returns(body, tail)?;
110                    if let Some(info) = else_ty {
111                        ret = Some(Self::merge_return_info(body.span, ret, info)?);
112                    }
113                    else_returns
114                } else {
115                    false
116                };
117                Ok((ret, then_returns && else_returns))
118            }
119            StmtKind::While { cond, body } => {
120                let cond_ty = self.infer_expr(cond)?;
121                if cond_ty != Type::Bool {
122                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
123                }
124                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
125            }
126            StmtKind::Loop(body) => self.infer_returns(body, false),
127            StmtKind::For { pat, range, body } => {
128                if let PatternKind::Var { idx, .. } = &pat.kind {
129                    let ty = self.infer_expr(range)?;
130                    self.set_ty(*idx, ty);
131                } else if let PatternKind::Tuple(pats) = &pat.kind {
132                    let ty = self.infer_expr(range)?;
133                    assert!(ty.is_any());
134                    for pat in pats {
135                        if let Some(idx) = pat.var() {
136                            self.set_ty(idx, Type::Any);
137                        }
138                    }
139                }
140                self.infer_returns(body, false).map(|(ty, _)| (ty, false))
141            }
142            StmtKind::Let { .. } => {
143                self.infer_stmt(stmt)?;
144                Ok((None, false))
145            }
146            StmtKind::Expr(expr, close) => {
147                let info = self.infer_return_expr(expr)?;
148                Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
149            }
150            _ => {
151                self.infer_stmt(stmt)?;
152                Ok((None, false))
153            }
154        }
155    }
156
157    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
158        match &expr.kind {
159            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
160            ExprKind::Value(v) if v.is_list() || v.is_map() => Ok(Type::Any),
161            ExprKind::Value(v) => Ok(v.get_type()),
162            ExprKind::Const(_) => Ok(Type::Any),
163            ExprKind::Var(idx) => {
164                let idx = self.top() + (*idx as usize);
165                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
166            }
167            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
168                Symbol::Const { ty, .. } => Ok(ty.clone()),
169                Symbol::Static { ty, .. } => Ok(ty.clone()),
170                Symbol::Struct(ty, _) => Ok(ty.clone()),
171                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
172                Symbol::Native(ty) => Ok(ty.clone()),
173                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
174            },
175            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
176            ExprKind::Unary { op, value } => match op {
177                UnaryOp::Not => {
178                    self.infer_expr(value.as_ref())?;
179                    Ok(Type::Bool)
180                }
181                UnaryOp::Neg => self.infer_expr(value.as_ref()),
182                UnaryOp::Unknow => Ok(Type::Any),
183            },
184            ExprKind::Binary { left, op, right } => {
185                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
186                let ty = if op.is_logic() {
187                    let left_ty = self.infer_expr(left)?;
188                    if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
189                } else if op == &BinaryOp::Idx {
190                    let left_ty = self.infer_expr(left)?;
191                    if let Type::Array(elem_ty, _) = left_ty {
192                        (*elem_ty).clone()
193                    } else if let Type::Vec(elem_ty, _) = left_ty {
194                        (*elem_ty).clone()
195                    } else {
196                        let left_ty = self.symbols.get_type(&left_ty)?;
197                        let right_ty = if right.is_value() || right.is_const() {
198                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
199                            if right_value.is_str() {
200                                if left_ty.is_any() {
201                                    return Ok(Type::Any);
202                                }
203                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
204                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
205                                }
206                            } else if let Type::Struct { fields, .. } = &left_ty
207                                && let Some(idx) = right_value.as_int()
208                            {
209                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
210                            }
211                            right_value.get_type()
212                        } else {
213                            self.infer_expr(right)?
214                        };
215                        if right_ty.is_int() || right_ty.is_uint() {
216                            if left_ty.is_any() {
217                                return Ok(Type::Any);
218                            }
219                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
220                            let fn_ty = self.symbols.get_type(&s)?;
221                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
222                        }
223                        if left_ty.is_any() {
224                            return Ok(Type::Any);
225                        }
226                        Type::Any
227                    }
228                } else {
229                    let right_ty = self.infer_expr(right)?;
230                    if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
231                };
232                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
233                Ok(ty)
234            }
235            ExprKind::Call { obj, params } => {
236                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
237                    let mut args = Vec::new();
238                    for p in params {
239                        args.push(self.infer_expr(p)?);
240                    }
241                    self.infer_fn_with_params(*id, &args, generic_args)
242                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
243                    let base_name = match ty {
244                        Type::Ident { name, .. } => name.clone(),
245                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
246                        _ => return Ok(Type::Any),
247                    };
248                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
249                    let mut args = vec![self.infer_expr(target)?];
250                    for p in params {
251                        args.push(self.infer_expr(p)?);
252                    }
253                    self.infer_fn(id, &args)
254                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
255                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
256                    for p in params {
257                        args.push(self.infer_expr(p)?);
258                    }
259                    self.infer_fn(*id, &args)
260                } else if obj.is_idx() {
261                    let (target, _, method) = obj.clone().binary().unwrap();
262                    let ty = self.infer_expr(&target)?;
263                    if let Some(method) = self.get_value(&method) {
264                        let method = method.as_str();
265                        let fn_ty = match self.get_field(&ty, method) {
266                            Ok((_, fn_ty)) => fn_ty,
267                            Err(_) => {
268                                let id = self.symbols.get_id(method)?;
269                                if self.symbols.get_symbol(id)?.1.is_fn() {
270                                    Type::Symbol { id, params: Vec::new() }
271                                } else {
272                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
273                                }
274                            }
275                        };
276                        if let Type::Symbol { id, .. } = fn_ty {
277                            let mut args = vec![ty];
278                            for p in params {
279                                args.push(self.infer_expr(p)?);
280                            }
281                            self.infer_fn(id, &args)
282                        } else {
283                            Ok(fn_ty)
284                        }
285                    } else {
286                        Ok(Type::Any)
287                    }
288                } else if let ExprKind::Var(idx) = &obj.kind {
289                    let idx = self.top() + (*idx as usize);
290                    if idx < self.tys.len()
291                        && let Type::Symbol { id, .. } = self.tys[idx]
292                    {
293                        let mut args = Vec::new();
294                        for p in params {
295                            args.push(self.infer_expr(p)?);
296                        }
297                        self.infer_fn(id, &args)
298                    } else {
299                        Ok(Type::Any)
300                    }
301                } else if obj.is_value() {
302                    Ok(Type::Void)
303                } else {
304                    Ok(Type::Any)
305                }
306            }
307            ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
308            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
309            ExprKind::Range { start, stop, .. } => {
310                let start_ty = self.infer_expr(start)?;
311                let stop_ty = self.infer_expr(stop)?;
312                Ok(if start_ty.is_any() {
313                    stop_ty
314                } else if stop_ty.is_any() {
315                    start_ty
316                } else {
317                    stop_ty
318                })
319            }
320            _ => Ok(Type::Any),
321        }
322    }
323
324    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
325        let mut fn_tys = Vec::new();
326        for (i, ty) in tys.iter().enumerate() {
327            if !ty.is_any() {
328                fn_tys.push(ty.clone());
329            } else if let Some(arg_ty) = arg_tys.get(i) {
330                fn_tys.push(self.symbols.get_type(arg_ty)?);
331            } else {
332                fn_tys.push(Type::Any);
333            }
334        }
335        Ok(fn_tys)
336    }
337
338    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
339        self.infer_fn_with_params(id, arg_tys, &[])
340    }
341
342    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
343        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
344        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
345            if let Type::Fn { tys, ret: _ } = ty {
346                let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
347                let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
348                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
349                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
350                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
351                let body = if generic_params.is_empty() {
352                    body
353                } else {
354                    let mut compile_tys = tys.clone();
355                    let mut compile_cap = cap.clone();
356                    let saved_state = self.take_local_state();
357                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
358                    self.restore_local_state(saved_state);
359                    Stmt::new(StmtKind::Block(compiled?), Span::default())
360                };
361                if let Some(fns) = self.fns.get_mut(&id) {
362                    for f in fns.iter() {
363                        if f.0 == generic_args && f.1 == fn_tys {
364                            return self.symbols.get_type(&f.2);
365                        }
366                    }
367                    fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
368                } else {
369                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
370                }
371                let top = self.tys.len();
372                self.tys.append(&mut fn_tys.clone());
373                for c in cap.vars.iter() {
374                    self.tys.push(self.tys[self.top() + *c].clone());
375                }
376                self.frames.push(top);
377                let ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
378                if let Some(top) = self.frames.pop() {
379                    self.tys.truncate(top);
380                }
381                let ret_ty = match ret_ty {
382                    Ok(ret_ty) => self.symbols.get_type(&ret_ty).unwrap_or(ret_ty),
383                    Err(err) => {
384                        log::error!("infer_fn {} failed: {:?}", name, err);
385                        let should_remove = self
386                            .fns
387                            .get_mut(&id)
388                            .map(|fns| {
389                                fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
390                                fns.is_empty()
391                            })
392                            .unwrap_or(false);
393                        if should_remove {
394                            self.fns.remove(&id);
395                        }
396                        return Err(err);
397                    }
398                };
399                self.fns.get_mut(&id).map(|f| {
400                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
401                });
402                if generic_args.is_empty()
403                    && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
404                    && ret.is_any()
405                {
406                    *ret = std::rc::Rc::new(ret_ty.clone());
407                }
408                Ok(ret_ty)
409            } else {
410                Ok(Type::Any)
411            }
412        } else if let Symbol::Native(f) = s {
413            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
414        } else if matches!(s, Symbol::Null) {
415            Ok(Type::Any)
416        } else {
417            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
418        }
419    }
420
421    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
422        match &stmt.kind {
423            StmtKind::Expr(expr, close) => {
424                if !close {
425                    self.infer_expr(expr)
426                } else {
427                    self.infer_expr(expr)?;
428                    Ok(Type::Void)
429                }
430            }
431            StmtKind::Return(expr) => {
432                if let Some(e) = expr {
433                    self.infer_expr(e)
434                } else {
435                    Ok(Type::Void)
436                }
437            }
438            StmtKind::Block(stmts) => {
439                for (idx, stmt) in stmts.iter().enumerate() {
440                    let ty = self.infer_stmt(stmt)?;
441                    if stmt.is_return() || idx == stmts.len() - 1 {
442                        return Ok(ty);
443                    }
444                }
445                Ok(Type::Void)
446            }
447            StmtKind::If { then_body, else_body, .. } => {
448                let then_ty = self.infer_stmt(then_body)?;
449                if let Some(e) = else_body {
450                    let else_ty = self.infer_stmt(e)?;
451                    if then_ty != else_ty {
452                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
453                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
454                    }
455                }
456                if else_body.is_none() {
457                    return Ok(Type::Void);
458                }
459                Ok(then_ty)
460            }
461            StmtKind::While { cond, body } => {
462                let cond_ty = self.infer_expr(cond)?;
463                if cond_ty != Type::Bool {
464                    return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
465                }
466                self.infer_stmt(body)
467            }
468            StmtKind::For { pat, range, body } => {
469                if let PatternKind::Var { idx, .. } = &pat.kind {
470                    let ty = self.infer_expr(range)?;
471                    self.set_ty(*idx, ty);
472                } else if let PatternKind::Tuple(pats) = &pat.kind {
473                    let ty = self.infer_expr(range)?;
474                    assert!(ty.is_any());
475                    for pat in pats {
476                        if let Some(idx) = pat.var() {
477                            self.set_ty(idx, Type::Any);
478                        }
479                    }
480                }
481                self.infer_stmt(body)
482            }
483            StmtKind::Let { pat, value } => {
484                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
485                if let PatternKind::Ident { ty, .. } = &pat.kind {
486                    let annotated_ty = self.symbols.get_type(ty)?;
487                    if annotated_ty.is_any() {
488                        self.add_ty(expr_ty);
489                    } else {
490                        self.add_ty(annotated_ty);
491                    }
492                } else if let PatternKind::Var { idx, .. } = &pat.kind {
493                    self.set_ty(*idx, expr_ty);
494                } else if matches!(pat.kind, PatternKind::Wildcard) {
495                    self.add_ty(expr_ty);
496                }
497                Ok(Type::Void)
498            }
499            _ => Ok(Type::Void),
500        }
501    }
502}