Skip to main content

compiler/
infer.rs

1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind};
5
6impl Compiler {
7    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
8        match &expr.kind {
9            ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
10            ExprKind::Value(v) => Ok(v.get_type()),
11            ExprKind::Var(idx) => {
12                let idx = self.top() + (*idx as usize);
13                if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
14            }
15            ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
16                Symbol::Const { ty, .. } => Ok(ty.clone()),
17                Symbol::Static { ty, .. } => Ok(ty.clone()),
18                Symbol::Struct(ty, _) => Ok(ty.clone()),
19                Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
20                Symbol::Native(ty) => Ok(ty.clone()),
21                s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
22            },
23            ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
24            ExprKind::Unary { value, .. } => self.infer_expr(value.as_ref()),
25            ExprKind::Binary { left, op, right } => {
26                let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
27                let ty = if op.is_logic() {
28                    let left_ty = self.infer_expr(left)?;
29                    if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
30                } else if op == &BinaryOp::Idx {
31                    let left_ty = self.infer_expr(left)?;
32                    if let Type::Array(elem_ty, _) = left_ty {
33                        (*elem_ty).clone()
34                    } else if let Type::Vec(elem_ty, _) = left_ty {
35                        (*elem_ty).clone()
36                    } else {
37                        let left_ty = self.symbols.get_type(&left_ty)?;
38                        let right_ty = if right.is_value() || right.is_const() {
39                            let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
40                            if right_value.is_str() {
41                                if left_ty.is_any() {
42                                    return Ok(Type::Any);
43                                }
44                                if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
45                                    return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
46                                }
47                            } else if let Type::Struct { fields, .. } = &left_ty
48                                && let Some(idx) = right_value.as_int()
49                            {
50                                return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
51                            }
52                            right_value.get_type()
53                        } else {
54                            self.infer_expr(right)?
55                        };
56                        if right_ty.is_int() || right_ty.is_uint() {
57                            if left_ty.is_any() {
58                                return Ok(Type::Any);
59                            }
60                            let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
61                            let fn_ty = self.symbols.get_type(&s)?;
62                            return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
63                        }
64                        if left_ty.is_any() {
65                            return Ok(Type::Any);
66                        }
67                        Type::Any
68                    }
69                } else {
70                    let right_ty = self.infer_expr(right)?;
71                    if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
72                };
73                assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
74                Ok(ty)
75            }
76            ExprKind::Call { obj, params } => {
77                if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
78                    let mut args = Vec::new();
79                    for p in params {
80                        args.push(self.infer_expr(p)?);
81                    }
82                    self.infer_fn_with_params(*id, &args, generic_args)
83                } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
84                    let base_name = match ty {
85                        Type::Ident { name, .. } => name.clone(),
86                        Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
87                        _ => return Ok(Type::Any),
88                    };
89                    let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
90                    let mut args = vec![self.infer_expr(target)?];
91                    for p in params {
92                        args.push(self.infer_expr(p)?);
93                    }
94                    self.infer_fn(id, &args)
95                } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
96                    let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
97                    for p in params {
98                        args.push(self.infer_expr(p)?);
99                    }
100                    self.infer_fn(*id, &args)
101                } else if obj.is_idx() {
102                    let (target, _, method) = obj.clone().binary().unwrap();
103                    let ty = self.infer_expr(&target)?;
104                    if let Some(method) = self.get_value(&method) {
105                        let method = method.as_str();
106                        let fn_ty = match self.get_field(&ty, method) {
107                            Ok((_, fn_ty)) => fn_ty,
108                            Err(_) => {
109                                let id = self.symbols.get_id(method)?;
110                                if self.symbols.get_symbol(id)?.1.is_fn() {
111                                    Type::Symbol { id, params: Vec::new() }
112                                } else {
113                                    return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
114                                }
115                            }
116                        };
117                        if let Type::Symbol { id, .. } = fn_ty {
118                            let mut args = vec![ty];
119                            for p in params {
120                                args.push(self.infer_expr(p)?);
121                            }
122                            self.infer_fn(id, &args)
123                        } else {
124                            Ok(fn_ty)
125                        }
126                    } else {
127                        Ok(Type::Any)
128                    }
129                } else if let ExprKind::Var(idx) = &obj.kind {
130                    let idx = self.top() + (*idx as usize);
131                    if idx < self.tys.len()
132                        && let Type::Symbol { id, .. } = self.tys[idx]
133                    {
134                        let mut args = Vec::new();
135                        for p in params {
136                            args.push(self.infer_expr(p)?);
137                        }
138                        self.infer_fn(id, &args)
139                    } else {
140                        Ok(Type::Any)
141                    }
142                } else if obj.is_value() {
143                    Ok(Type::Void)
144                } else {
145                    Ok(Type::Any)
146                }
147            }
148            ExprKind::Typed { ty, .. } => Ok(ty.clone()),
149            ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
150            ExprKind::Range { start, stop, .. } => {
151                let start_ty = self.infer_expr(start)?;
152                let stop_ty = self.infer_expr(stop)?;
153                Ok(if start_ty.is_any() {
154                    stop_ty
155                } else if stop_ty.is_any() {
156                    start_ty
157                } else {
158                    stop_ty
159                })
160            }
161            _ => Ok(Type::Any),
162        }
163    }
164
165    fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
166        let mut fn_tys = Vec::new();
167        for (i, ty) in tys.iter().enumerate() {
168            if !ty.is_any() {
169                fn_tys.push(ty.clone());
170            } else if let Some(arg_ty) = arg_tys.get(i) {
171                fn_tys.push(self.symbols.get_type(arg_ty)?);
172            } else {
173                fn_tys.push(Type::Any);
174            }
175        }
176        Ok(fn_tys)
177    }
178
179    pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
180        self.infer_fn_with_params(id, arg_tys, &[])
181    }
182
183    pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
184        let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
185        if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
186            if let Type::Fn { tys, ret: _ } = ty {
187                let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
188                let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
189                let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
190                let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
191                let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
192                let body = if generic_params.is_empty() {
193                    body
194                } else {
195                    let mut compile_tys = tys.clone();
196                    let mut compile_cap = cap.clone();
197                    let saved_state = self.take_local_state();
198                    let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
199                    self.restore_local_state(saved_state);
200                    Stmt::new(StmtKind::Block(compiled?), Span::default())
201                };
202                if let Some(fns) = self.fns.get_mut(&id) {
203                    for f in fns.iter() {
204                        if f.0 == generic_args && f.1 == fn_tys {
205                            return Ok(f.2.clone());
206                        }
207                    }
208                    fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
209                } else {
210                    self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
211                }
212                let top = self.tys.len();
213                self.tys.append(&mut fn_tys.clone());
214                for c in cap.vars.iter() {
215                    self.tys.push(self.tys[self.top() + *c].clone());
216                }
217                self.frames.push(top);
218                let ret_ty = self.infer_stmt(&body);
219                if let Some(top) = self.frames.pop() {
220                    self.tys.truncate(top);
221                }
222                let ret_ty = match ret_ty {
223                    Ok(ret_ty) => ret_ty,
224                    Err(err) => {
225                        log::error!("infer_fn {} failed: {:?}", name, err);
226                        let should_remove = self
227                            .fns
228                            .get_mut(&id)
229                            .map(|fns| {
230                                fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
231                                fns.is_empty()
232                            })
233                            .unwrap_or(false);
234                        if should_remove {
235                            self.fns.remove(&id);
236                        }
237                        return Err(err);
238                    }
239                };
240                self.fns.get_mut(&id).map(|f| {
241                    f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
242                });
243                Ok(ret_ty)
244            } else {
245                Ok(Type::Any)
246            }
247        } else if let Symbol::Native(f) = s {
248            if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
249        } else if matches!(s, Symbol::Null) {
250            Ok(Type::Any)
251        } else {
252            Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
253        }
254    }
255
256    pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
257        match &stmt.kind {
258            StmtKind::Expr(expr, close) => {
259                if !close {
260                    self.infer_expr(expr)
261                } else {
262                    self.infer_expr(expr)?;
263                    Ok(Type::Void)
264                }
265            }
266            StmtKind::Return(expr) => {
267                if let Some(e) = expr {
268                    self.infer_expr(e)
269                } else {
270                    Ok(Type::Void)
271                }
272            }
273            StmtKind::Block(stmts) => {
274                for (idx, stmt) in stmts.iter().enumerate() {
275                    let ty = self.infer_stmt(stmt)?;
276                    if stmt.is_return() || idx == stmts.len() - 1 {
277                        return Ok(ty);
278                    }
279                }
280                Ok(Type::Void)
281            }
282            StmtKind::If { then_body, else_body, .. } => {
283                let then_ty = self.infer_stmt(then_body)?;
284                if let Some(e) = else_body {
285                    let else_ty = self.infer_stmt(e)?;
286                    if then_ty != else_ty {
287                        log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
288                        return Ok(if then_ty.is_any() { else_ty } else { then_ty });
289                    }
290                }
291                if else_body.is_none() {
292                    return Ok(Type::Void);
293                }
294                Ok(then_ty)
295            }
296            StmtKind::While { cond, body } => {
297                let cond_ty = self.infer_expr(cond)?;
298                if cond_ty != Type::Bool {
299                    return Err(Self::semantic_error(cond.span, "条件表达式必须是布尔类型"));
300                }
301                self.infer_stmt(body)
302            }
303            StmtKind::For { pat, range, body } => {
304                if let PatternKind::Var { idx, .. } = &pat.kind {
305                    let ty = self.infer_expr(range)?;
306                    self.set_ty(*idx, ty);
307                } else if let PatternKind::Tuple(pats) = &pat.kind {
308                    let ty = self.infer_expr(range)?;
309                    assert!(ty.is_any());
310                    for pat in pats {
311                        if let Some(idx) = pat.var() {
312                            self.set_ty(idx, Type::Any);
313                        }
314                    }
315                }
316                self.infer_stmt(body)
317            }
318            StmtKind::Let { pat, value } => {
319                let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
320                if let PatternKind::Ident { ty, .. } = &pat.kind {
321                    let annotated_ty = self.symbols.get_type(ty)?;
322                    if annotated_ty.is_any() {
323                        self.add_ty(expr_ty);
324                    } else {
325                        self.add_ty(annotated_ty);
326                    }
327                } else if let PatternKind::Var { idx, .. } = &pat.kind {
328                    self.set_ty(*idx, expr_ty);
329                } else if matches!(pat.kind, PatternKind::Wildcard) {
330                    self.add_ty(expr_ty);
331                }
332                Ok(Type::Void)
333            }
334            _ => Ok(Type::Void),
335        }
336    }
337}