1use super::{Compiler, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, PatternKind, Span, Stmt, StmtKind};
5
6impl Compiler {
7 pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
8 match &expr.kind {
9 ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
10 ExprKind::Value(v) => Ok(v.get_type()),
11 ExprKind::Var(idx) => {
12 let idx = self.top() + (*idx as usize);
13 if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
14 }
15 ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
16 Symbol::Const { ty, .. } => Ok(ty.clone()),
17 Symbol::Static { ty, .. } => Ok(ty.clone()),
18 Symbol::Struct(ty, _) => Ok(ty.clone()),
19 Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
20 Symbol::Native(ty) => Ok(ty.clone()),
21 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
22 },
23 ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
24 ExprKind::Unary { value, .. } => self.infer_expr(value.as_ref()),
25 ExprKind::Binary { left, op, right } => {
26 let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
27 let ty = if op.is_logic() {
28 let left_ty = self.infer_expr(left)?;
29 if matches!(op, BinaryOp::And | BinaryOp::Or) && left_ty.is_any() { Type::Any } else { Type::Bool }
30 } else if op == &BinaryOp::Idx {
31 let left_ty = self.infer_expr(left)?;
32 if let Type::Array(elem_ty, _) = left_ty {
33 (*elem_ty).clone()
34 } else if let Type::Vec(elem_ty, _) = left_ty {
35 (*elem_ty).clone()
36 } else {
37 let left_ty = self.symbols.get_type(&left_ty)?;
38 let right_ty = if right.is_value() || right.is_const() {
39 let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
40 if right_value.is_str() {
41 if left_ty.is_any() {
42 return Ok(Type::Any);
43 }
44 if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
45 return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
46 }
47 } else if let Type::Struct { fields, .. } = &left_ty
48 && let Some(idx) = right_value.as_int()
49 {
50 return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
51 }
52 right_value.get_type()
53 } else {
54 self.infer_expr(right)?
55 };
56 if right_ty.is_int() || right_ty.is_uint() {
57 if left_ty.is_any() {
58 return Ok(Type::Any);
59 }
60 let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
61 let fn_ty = self.symbols.get_type(&s)?;
62 return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
63 }
64 if left_ty.is_any() {
65 return Ok(Type::Any);
66 }
67 Type::Any
68 }
69 } else {
70 let right_ty = self.infer_expr(right)?;
71 if op == &BinaryOp::Assign { right_ty } else { self.infer_expr(left)? + right_ty }
72 };
73 assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
74 Ok(ty)
75 }
76 ExprKind::Call { obj, params } => {
77 if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
78 let mut args = Vec::new();
79 for p in params {
80 args.push(self.infer_expr(p)?);
81 }
82 self.infer_fn_with_params(*id, &args, generic_args)
83 } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
84 let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
85 for p in params {
86 args.push(self.infer_expr(p)?);
87 }
88 self.infer_fn(*id, &args)
89 } else if obj.is_idx() {
90 let (target, _, method) = obj.clone().binary().unwrap();
91 let ty = self.infer_expr(&target)?;
92 if let Some(method) = self.get_value(&method) {
93 let method = method.as_str();
94 let fn_ty = match self.get_field(&ty, method) {
95 Ok((_, fn_ty)) => fn_ty,
96 Err(_) => {
97 let id = self.symbols.get_id(method)?;
98 if self.symbols.get_symbol(id)?.1.is_fn() {
99 Type::Symbol { id, params: Vec::new() }
100 } else {
101 return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
102 }
103 }
104 };
105 if let Type::Symbol { id, .. } = fn_ty {
106 let mut args = vec![ty];
107 for p in params {
108 args.push(self.infer_expr(p)?);
109 }
110 self.infer_fn(id, &args)
111 } else {
112 Ok(fn_ty)
113 }
114 } else {
115 Ok(Type::Any)
116 }
117 } else if let ExprKind::Var(idx) = &obj.kind {
118 let idx = self.top() + (*idx as usize);
119 if idx < self.tys.len()
120 && let Type::Symbol { id, .. } = self.tys[idx]
121 {
122 let mut args = Vec::new();
123 for p in params {
124 args.push(self.infer_expr(p)?);
125 }
126 self.infer_fn(id, &args)
127 } else {
128 Ok(Type::Any)
129 }
130 } else if obj.is_value() {
131 Ok(Type::Void)
132 } else {
133 Ok(Type::Any)
134 }
135 }
136 ExprKind::Typed { ty, .. } => Ok(ty.clone()),
137 ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
138 ExprKind::Range { start, stop, .. } => {
139 let start_ty = self.infer_expr(start)?;
140 let stop_ty = self.infer_expr(stop)?;
141 Ok(if start_ty.is_any() {
142 stop_ty
143 } else if stop_ty.is_any() {
144 start_ty
145 } else {
146 stop_ty
147 })
148 }
149 _ => Ok(Type::Any),
150 }
151 }
152
153 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
154 let mut fn_tys = Vec::new();
155 for (i, ty) in tys.iter().enumerate() {
156 if !ty.is_any() {
157 fn_tys.push(ty.clone());
158 } else if let Some(arg_ty) = arg_tys.get(i) {
159 fn_tys.push(self.symbols.get_type(arg_ty)?);
160 } else {
161 fn_tys.push(Type::Any);
162 }
163 }
164 Ok(fn_tys)
165 }
166
167 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
168 self.infer_fn_with_params(id, arg_tys, &[])
169 }
170
171 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
172 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
173 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
174 if let Type::Fn { tys, ret: _ } = ty {
175 let inferred_generic_args = if generic_args.is_empty() { crate::infer_generic_args_from_types(&generic_params, &tys, arg_tys) } else { generic_args.to_vec() };
176 let generic_args = if generic_params.is_empty() { &[] } else { inferred_generic_args.as_slice() };
177 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
178 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
179 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
180 let body = if generic_params.is_empty() {
181 body
182 } else {
183 let mut compile_tys = tys.clone();
184 let mut compile_cap = cap.clone();
185 let saved_state = self.take_local_state();
186 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
187 self.restore_local_state(saved_state);
188 Stmt::new(StmtKind::Block(compiled?), Span::default())
189 };
190 if let Some(fns) = self.fns.get_mut(&id) {
191 for f in fns.iter() {
192 if f.0 == generic_args && f.1 == fn_tys {
193 return Ok(f.2.clone());
194 }
195 }
196 fns.push((generic_args.to_vec(), fn_tys.clone(), Type::Any));
197 } else {
198 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), Type::Any)]);
199 }
200 let top = self.tys.len();
201 self.tys.append(&mut fn_tys.clone());
202 for c in cap.vars.iter() {
203 self.tys.push(self.tys[self.top() + *c].clone());
204 }
205 self.frames.push(top);
206 let ret_ty = self.infer_stmt(&body);
207 if let Some(top) = self.frames.pop() {
208 self.tys.truncate(top);
209 }
210 let ret_ty = match ret_ty {
211 Ok(ret_ty) => ret_ty,
212 Err(err) => {
213 log::error!("infer_fn {} failed: {:?}", name, err);
214 let should_remove = self
215 .fns
216 .get_mut(&id)
217 .map(|fns| {
218 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || item.2 != Type::Any);
219 fns.is_empty()
220 })
221 .unwrap_or(false);
222 if should_remove {
223 self.fns.remove(&id);
224 }
225 return Err(err);
226 }
227 };
228 self.fns.get_mut(&id).map(|f| {
229 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = ret_ty.clone());
230 });
231 Ok(ret_ty)
232 } else {
233 Ok(Type::Any)
234 }
235 } else if let Symbol::Native(f) = s {
236 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
237 } else if matches!(s, Symbol::Null) {
238 Ok(Type::Any)
239 } else {
240 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
241 }
242 }
243
244 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
245 match &stmt.kind {
246 StmtKind::Expr(expr, close) => {
247 if !close {
248 self.infer_expr(expr)
249 } else {
250 self.infer_expr(expr)?;
251 Ok(Type::Void)
252 }
253 }
254 StmtKind::Return(expr) => {
255 if let Some(e) = expr {
256 self.infer_expr(e)
257 } else {
258 Ok(Type::Void)
259 }
260 }
261 StmtKind::Block(stmts) => {
262 for (idx, stmt) in stmts.iter().enumerate() {
263 let ty = self.infer_stmt(stmt)?;
264 if stmt.is_return() || idx == stmts.len() - 1 {
265 return Ok(ty);
266 }
267 }
268 Ok(Type::Void)
269 }
270 StmtKind::If { then_body, else_body, .. } => {
271 let then_ty = self.infer_stmt(then_body)?;
272 if let Some(e) = else_body {
273 let else_ty = self.infer_stmt(e)?;
274 if then_ty != else_ty {
275 log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
276 return Ok(if then_ty.is_any() { else_ty } else { then_ty });
277 }
278 }
279 if else_body.is_none() {
280 return Ok(Type::Void);
281 }
282 Ok(then_ty)
283 }
284 StmtKind::While { cond, body } => {
285 let cond_ty = self.infer_expr(cond)?;
286 if cond_ty != Type::Bool {
287 return Err(Self::semantic_error(cond.span, "条件表达式必须是布尔类型"));
288 }
289 self.infer_stmt(body)
290 }
291 StmtKind::For { pat, range, body } => {
292 if let PatternKind::Var { idx, .. } = &pat.kind {
293 let ty = self.infer_expr(range)?;
294 self.set_ty(*idx, ty);
295 } else if let PatternKind::Tuple(pats) = &pat.kind {
296 let ty = self.infer_expr(range)?;
297 assert!(ty.is_any());
298 for pat in pats {
299 if let Some(idx) = pat.var() {
300 self.set_ty(idx, Type::Any);
301 }
302 }
303 }
304 self.infer_stmt(body)
305 }
306 StmtKind::Let { pat, value } => {
307 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
308 if let PatternKind::Ident { ty, .. } = &pat.kind {
309 let annotated_ty = self.symbols.get_type(ty)?;
310 if annotated_ty.is_any() {
311 self.add_ty(expr_ty);
312 } else {
313 self.add_ty(annotated_ty);
314 }
315 } else if let PatternKind::Var { idx, .. } = &pat.kind {
316 self.set_ty(*idx, expr_ty);
317 } else if matches!(pat.kind, PatternKind::Wildcard) {
318 self.add_ty(expr_ty);
319 }
320 Ok(Type::Void)
321 }
322 _ => Ok(Type::Void),
323 }
324 }
325}