1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5use smol_str::SmolStr;
6
7#[derive(Clone)]
8struct ReturnInfo {
9 ty: Type,
10 shape: Option<Type>,
11}
12
/// Limit on nested function-return inference: once `infer_stack` holds more than
/// this many entries, further inference requests give up and yield `Any`.
const MAX_INFER_DEPTH: usize = 1 << 6;
18
19impl Compiler {
20 fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
21 self.infer_stack.last().cloned()
22 }
23
24 fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
25 self.fns.get(&id).and_then(|fns| {
26 fns.iter().find_map(|item| {
27 if item.0 == generic_args
28 && item.1 == fn_tys
29 && let FnInferRet::Pending(seed) = &item.2
30 {
31 seed.clone()
32 } else {
33 None
34 }
35 })
36 })
37 }
38
39 fn update_pending_return_seed(&mut self, ty: &Type) {
40 if ty.is_any() {
41 return;
42 }
43 let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
44 return;
45 };
46 let Some(fns) = self.fns.get_mut(&id) else {
47 return;
48 };
49 if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
50 && let FnInferRet::Pending(seed) = &mut item.2
51 {
52 let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
53 *seed = Some(next);
54 }
55 }
56
57 fn try_find_base_return_ty(&self, body: &Stmt) -> Option<Type> {
59 match &body.kind {
60 StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty(s)),
61 StmtKind::If { then_body, else_body, .. } => self.try_find_base_return_ty(then_body).or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty(b))),
62 StmtKind::Return(Some(expr)) => Self::try_literal_type(expr),
63 StmtKind::Expr(expr, false) => Self::try_literal_type(expr),
64 _ => None,
65 }
66 }
67
68 fn try_find_base_return_ty_with_scope(&mut self, body: &Stmt, fn_id: u32, fn_name: &str, args: &[SmolStr], fn_tys: &[Type]) -> Option<Type> {
70 let saved_state = self.take_local_state();
71 self.frames.push(0);
72 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
73 self.add_name(arg.clone());
74 self.add_ty(ty.clone());
75 }
76 let result = self.try_find_base_return_ty_with_scope_inner(body, fn_id, fn_name);
77 self.restore_local_state(saved_state);
78 result
79 }
80
81 fn try_find_base_return_ty_with_scope_inner(&mut self, body: &Stmt, fn_id: u32, fn_name: &str) -> Option<Type> {
82 match &body.kind {
83 StmtKind::Block(stmts) => stmts.iter().find_map(|s| self.try_find_base_return_ty_with_scope_inner(s, fn_id, fn_name)),
84 StmtKind::If { then_body, else_body, .. } => {
85 self.try_find_base_return_ty_with_scope_inner(then_body, fn_id, fn_name).or_else(|| else_body.as_ref().and_then(|b| self.try_find_base_return_ty_with_scope_inner(b, fn_id, fn_name)))
86 }
87 StmtKind::Return(Some(expr)) => {
88 if Self::expr_calls_fn(expr, fn_id, fn_name) {
89 None
90 } else {
91 self.infer_return_expr(expr).ok().map(|info| info.ty)
92 }
93 }
94 StmtKind::Expr(expr, false) => {
95 if Self::expr_calls_fn(expr, fn_id, fn_name) {
96 None
97 } else {
98 self.infer_return_expr(expr).ok().map(|info| info.ty)
99 }
100 }
101 _ => None,
102 }
103 }
104
105 fn expr_calls_fn(expr: &Expr, fn_id: u32, fn_name: &str) -> bool {
106 match &expr.kind {
107 ExprKind::Call { obj, params } => {
108 if let ExprKind::Id(id, _) = &obj.kind {
109 return *id == fn_id;
110 }
111 if let ExprKind::Ident(name) = &obj.kind {
112 if name.as_str() == fn_name || fn_name.ends_with(&format!("::{}", name)) {
113 return true;
114 }
115 }
116 params.iter().any(|p| Self::expr_calls_fn(p, fn_id, fn_name))
117 }
118 ExprKind::Binary { left, op: _, right } => Self::expr_calls_fn(left, fn_id, fn_name) || Self::expr_calls_fn(right, fn_id, fn_name),
119 ExprKind::Unary { op: _, value } => Self::expr_calls_fn(value, fn_id, fn_name),
120 ExprKind::Typed { value, ty: _ } => Self::expr_calls_fn(value, fn_id, fn_name),
121 _ => false,
122 }
123 }
124
125 fn try_literal_type(expr: &Expr) -> Option<Type> {
126 match &expr.kind {
127 ExprKind::Value(v) => Some(v.get_type()),
128 ExprKind::Unary { op: UnaryOp::Neg, value } => Self::try_literal_type(value),
129 _ => None,
130 }
131 }
132
133 fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
134 match &pat.kind {
135 PatternKind::Ident { name, ty } => {
136 let annotated_ty = self.symbols.get_type(ty)?;
137 self.add_name(name.clone());
138 self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
139 }
140 PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
141 PatternKind::Tuple(pats) => {
142 if let Type::Tuple(tys) = expr_ty {
143 for (pat, ty) in pats.iter().zip(tys) {
144 self.add_pattern_bindings_for_infer(pat, ty)?;
145 }
146 } else {
147 for pat in pats {
148 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
149 }
150 }
151 }
152 PatternKind::List { elems, .. } => {
153 for pat in elems {
154 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
155 }
156 }
157 PatternKind::Wildcard => {
158 self.add_name("".into());
159 self.add_ty(expr_ty);
160 }
161 PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) | PatternKind::Struct { .. } => {}
162 }
163 Ok(())
164 }
165
166 fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
167 if matches!(range.kind, ExprKind::Range { .. }) {
168 return self.infer_range_expr(range);
169 }
170 Ok(match self.infer_expr(range)? {
171 Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
172 _ => Type::Any,
173 })
174 }
175
176 fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
177 let ExprKind::Range { start, stop, .. } = &range.kind else {
178 return self.infer_expr(range);
179 };
180 let start_ty = self.infer_expr(start)?;
181 let stop_ty = self.infer_expr(stop)?;
182 Ok(Self::merge_range_bound_types(start_ty, stop_ty))
183 }
184
185 fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
186 if start_ty.is_any() {
187 stop_ty
188 } else if stop_ty.is_any() {
189 start_ty
190 } else if matches!(start_ty, Type::I32 | Type::I64) && stop_ty.is_uint() {
193 stop_ty
194 } else if matches!(stop_ty, Type::I32 | Type::I64) && start_ty.is_uint() {
195 start_ty
196 } else {
197 start_ty + stop_ty
198 }
199 }
200
201 fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
202 match left {
203 Some(left) if left == right => Ok(left),
204 Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
205 Some(left) if left.is_any() || right.is_any() => Ok(Type::Any),
206 Some(left) => Ok(left + right),
207 None => Ok(right),
208 }
209 }
210
211 fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
212 if !ty.is_any() {
213 return match ty {
214 Type::Struct { .. } => Some(ty.clone()),
215 Type::Map => Some(Type::Map),
216 Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
217 _ => None,
218 };
219 }
220 match &expr.kind {
221 ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
222 ExprKind::Dict(_) => Some(Type::Map),
223 ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
224 ExprKind::Const(idx) => self.consts.get_index(*idx).and_then(|(_, value)| Self::dynamic_return_shape(value.get_type())),
225 ExprKind::Typed { ty, .. } => Some(ty.clone()),
226 _ => None,
227 }
228 }
229
230 fn dynamic_return_shape(ty: Type) -> Option<Type> {
231 match ty {
232 Type::Map => Some(Type::Map),
233 Type::List(elem) => Some(Type::List(elem)),
234 Type::Array(elem, _) => Some(Type::List(elem)),
235 _ => None,
236 }
237 }
238
239 fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
240 match &expr.kind {
241 ExprKind::Var(idx) => Some(*idx),
242 ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
243 _ => None,
244 }
245 }
246
247 fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
248 match method {
249 "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
250 Some(ListElemState::Known(ty)) => ty,
251 Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
252 None => elem_ty.clone(),
253 })),
254 "push" => {
255 let pushed_ty = params
256 .first()
257 .map(|param| {
258 if let Some(value) = self.get_value(param)
259 && (value.is_str() || value.is_native())
260 {
261 Ok(value.get_type())
262 } else {
263 self.infer_expr(param)
264 }
265 })
266 .transpose()?
267 .unwrap_or(Type::Any);
268 if let Some(idx) = self.local_var_idx_for_expr(target) {
269 let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
270 let next_state = match state {
271 ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
272 ListElemState::Unknown => ListElemState::Known(pushed_ty),
273 ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
274 ListElemState::Known(prev) => {
275 let merged = if prev == pushed_ty {
276 prev
277 } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
278 prev + pushed_ty
279 } else {
280 Type::Any
281 };
282 if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
283 }
284 ListElemState::Mixed => ListElemState::Mixed,
285 };
286 let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
287 self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
288 self.set_list_elem_state(idx, Some(next_state));
289 }
290 Ok(Some(Type::Void))
291 }
292 "len" => Ok(Some(Type::I32)),
293 "is_list" | "is_null" => Ok(Some(Type::Bool)),
294 _ => Ok(None),
295 }
296 }
297
298 fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
299 let ty = self.infer_expr(expr)?;
300 let shape = self.return_shape(expr, &ty);
301 let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
302 Ok(ReturnInfo { ty, shape })
303 }
304
305 fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
306 let Some(left) = left else {
307 return Ok(right);
308 };
309 if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
310 && left_shape != right_shape
311 {
312 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
313 }
314 if let Some(left_shape) = &left.shape
315 && left_shape.is_struct()
316 && right.ty.is_any()
317 && right.shape.is_none()
318 {
319 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
320 }
321 if let Some(right_shape) = &right.shape
322 && right_shape.is_struct()
323 && left.ty.is_any()
324 && left.shape.is_none()
325 {
326 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
327 }
328 let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
329 Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
330 }
331
332 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
333 self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
334 }
335
336 pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
337 self.infer_returns(stmt, true).map(|_| ())
338 }
339
340 fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
341 match &stmt.kind {
342 StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
343 StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
344 StmtKind::Block(stmts) => {
345 let mut ret = None;
346 for (idx, stmt) in stmts.iter().enumerate() {
347 let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
348 if let Some(info) = info {
349 self.update_pending_return_seed(&info.ty);
350 ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
351 if let Some(ret) = &ret {
352 self.update_pending_return_seed(&ret.ty);
353 }
354 }
355 if always_returns {
356 return Ok((ret, true));
357 }
358 }
359 Ok((ret, false))
360 }
361 StmtKind::If { cond, then_body, else_body } => {
362 let cond_ty = self.infer_expr(cond)?;
363 if cond_ty != Type::Bool {
364 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
365 }
366 let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
367 if let Some(ret) = &ret {
368 self.update_pending_return_seed(&ret.ty);
369 }
370 let else_returns = if let Some(body) = else_body {
371 let (else_ty, else_returns) = self.infer_returns(body, tail)?;
372 if let Some(info) = else_ty {
373 self.update_pending_return_seed(&info.ty);
374 ret = Some(Self::merge_return_info(body.span, ret, info)?);
375 if let Some(ret) = &ret {
376 self.update_pending_return_seed(&ret.ty);
377 }
378 }
379 else_returns
380 } else {
381 false
382 };
383 Ok((ret, then_returns && else_returns))
384 }
385 StmtKind::While { cond, body } => {
386 let cond_ty = self.infer_expr(cond)?;
387 if cond_ty != Type::Bool {
388 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
389 }
390 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
391 }
392 StmtKind::Loop(body) => self.infer_returns(body, false),
393 StmtKind::For { pat, range, body } => {
394 let ty = self.for_pattern_ty(range)?;
395 self.add_pattern_bindings_for_infer(pat, ty)?;
396 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
397 }
398 StmtKind::Let { .. } => {
399 self.infer_stmt(stmt)?;
400 Ok((None, false))
401 }
402 StmtKind::Expr(expr, close) => {
403 let info = self.infer_return_expr(expr)?;
404 Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
405 }
406 _ => {
407 self.infer_stmt(stmt)?;
408 Ok((None, false))
409 }
410 }
411 }
412
413 pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
414 match &expr.kind {
415 ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
416 ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
417 ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
418 ExprKind::Value(v) => Ok(v.get_type()),
419 ExprKind::Const(idx) => Ok(match self.consts.get_index(*idx) {
420 Some((_, value)) if value.is_str() => Type::Str,
421 Some((_, value)) if value.is_list() && value.len() == 0 => Type::list_any(),
422 _ => Type::Any,
423 }),
424 ExprKind::Var(idx) => {
425 let idx = self.top() + (*idx as usize);
426 if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
427 }
428 ExprKind::Ident(ident) => {
429 for idx in (self.top()..self.names.len()).rev() {
430 if self.names[idx].eq(ident) && idx < self.tys.len() {
431 return self.symbols.get_type(&self.tys[idx]);
432 }
433 }
434 let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
435 match self.symbols.get_symbol(id)?.1 {
436 Symbol::Const { ty, .. } => Ok(ty.clone()),
437 Symbol::Static { ty, .. } => Ok(ty.clone()),
438 Symbol::Struct(ty, _) => Ok(ty.clone()),
439 Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
440 Symbol::Native(ty) => Ok(ty.clone()),
441 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
442 }
443 }
444 ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
445 Symbol::Const { ty, .. } => Ok(ty.clone()),
446 Symbol::Static { ty, .. } => Ok(ty.clone()),
447 Symbol::Struct(ty, _) => Ok(ty.clone()),
448 Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
449 Symbol::Native(ty) => Ok(ty.clone()),
450 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
451 },
452 ExprKind::Generic { obj, params } => {
453 let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
454 match self.infer_expr(obj)? {
455 Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
456 _ => Ok(Type::Any),
457 }
458 }
459 ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
460 ExprKind::Unary { op, value } => match op {
461 UnaryOp::Not => {
462 let ty = self.infer_expr(value.as_ref())?;
463 if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
464 }
465 UnaryOp::Neg => self.infer_expr(value.as_ref()),
466 UnaryOp::Unknow => Ok(Type::Any),
467 },
468 ExprKind::Binary { left, op, right } => {
469 if op == &BinaryOp::Assign
470 && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
471 {
472 if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
473 if left_items.len() != right_items.len() {
474 return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
475 }
476 for item in right_items {
477 let _ = self.infer_expr(item)?;
478 }
479 } else {
480 let _ = self.infer_expr(right)?;
481 }
482 return Ok(Type::Void);
483 }
484 let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
485 let ty = if op.is_logic() {
486 Type::Bool
487 } else if op == &BinaryOp::Idx {
488 let left_ty = self.infer_expr(left)?;
489 if matches!(right.kind, ExprKind::Range { .. }) {
490 let elem_ty = match &left_ty {
493 Type::Array(e, _) | Type::Vec(e, _) | Type::List(e) => (**e).clone(),
494 _ => Type::Any,
495 };
496 return Ok(Type::List(std::rc::Rc::new(elem_ty)));
497 }
498 if let Type::Array(elem_ty, _) = left_ty {
499 (*elem_ty).clone()
500 } else if let Type::Vec(elem_ty, _) = left_ty {
501 (*elem_ty).clone()
502 } else if let Type::List(elem_ty) = left_ty {
503 (*elem_ty).clone()
504 } else {
505 let left_ty = self.symbols.get_type(&left_ty)?;
506 let right_ty = if right.is_value() || right.is_const() {
507 let right_value = if let ExprKind::Const(c) = &right.kind {
508 match self.consts.get_index(*c) {
509 Some((_, v)) => v.clone(),
510 None => right.clone().value()?,
511 }
512 } else {
513 right.clone().value()?
514 };
515 if right_value.is_str() {
516 if left_ty.is_any() {
517 return Ok(Type::Any);
518 }
519 if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
520 return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
521 }
522 } else if let Type::Struct { fields, .. } = &left_ty
523 && let Some(idx) = right_value.as_int()
524 {
525 return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
526 }
527 right_value.get_type()
528 } else {
529 self.infer_expr(right)?
530 };
531 if right_ty.is_int() || right_ty.is_uint() {
532 if left_ty.is_any() {
533 return Ok(Type::Any);
534 }
535 let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
536 let fn_ty = self.symbols.get_type(&s)?;
537 return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
538 }
539 if left_ty.is_any() {
540 return Ok(Type::Any);
541 }
542 Type::Any
543 }
544 } else {
545 let left_ty = self.infer_expr(left)?;
546 let right_ty = self.infer_expr(right)?;
547 if op == &BinaryOp::Assign {
548 if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
549 } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
550 left_ty
551 } else {
552 left_ty + right_ty
553 }
554 };
555 assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
556 Ok(ty)
557 }
558 ExprKind::Call { obj, params } => {
559 if let ExprKind::Assoc { ty, name } = &obj.kind {
560 let base_name = match ty {
561 Type::Ident { name, .. } => name.clone(),
562 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
563 _ => return Ok(Type::Any),
564 };
565 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
566 let generic_args = match ty {
567 Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
568 _ => Vec::new(),
569 };
570 let mut args = Vec::new();
571 for p in params {
572 args.push(self.infer_expr(p)?);
573 }
574 self.infer_fn_with_params(id, &args, &generic_args)
575 } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
576 let mut args = Vec::new();
577 for p in params {
578 args.push(self.infer_expr(p)?);
579 }
580 self.infer_fn_with_params(*id, &args, generic_args)
581 } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
582 let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
583 return Ok(Type::Any);
584 };
585 let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
586 let mut args = Vec::new();
587 for p in params {
588 args.push(self.infer_expr(p)?);
589 }
590 self.infer_fn_with_params(id, &args, &generic_args)
591 } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
592 let base_name = match ty {
593 Type::Ident { name, .. } => name.clone(),
594 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
595 _ => return Ok(Type::Any),
596 };
597 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
598 let mut args = vec![self.infer_expr(target)?];
599 for p in params {
600 args.push(self.infer_expr(p)?);
601 }
602 self.infer_fn(id, &args)
603 } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
604 let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
605 if let Some(target) = obj_expr
606 && let Some(method) = method
607 {
608 let target_ty = self.infer_expr(target)?;
609 if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
610 && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
611 {
612 return Ok(ret_ty);
613 }
614 }
615 let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
616 for p in params {
617 args.push(self.infer_expr(p)?);
618 }
619 self.infer_fn(*id, &args)
620 } else if let ExprKind::Ident(name) = &obj.kind {
621 for idx in (self.top()..self.names.len()).rev() {
622 if self.names[idx].eq(name) && idx < self.tys.len() {
623 return if let Type::Symbol { id, .. } = &self.tys[idx] {
624 let id = *id;
625 let mut args = Vec::new();
626 for p in params {
627 args.push(self.infer_expr(p)?);
628 }
629 self.infer_fn(id, &args)
630 } else {
631 Ok(Type::Any)
632 };
633 }
634 }
635 let Ok(id) = self.symbols.get_id(name) else {
636 return Ok(Type::Any);
637 };
638 if !self.symbols.get_symbol(id)?.1.is_fn() {
639 return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
640 }
641 let mut args = Vec::new();
642 for p in params {
643 args.push(self.infer_expr(p)?);
644 }
645 self.infer_fn(id, &args)
646 } else if obj.is_idx() {
647 let (target, _, method) = obj.clone().binary().unwrap();
648 let ty = self.infer_expr(&target)?;
649 if let Some(method) = self.get_value(&method) {
650 let method = method.as_str();
651 if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
652 && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
653 {
654 return Ok(ret_ty);
655 }
656 let fn_ty = match self.get_field(&ty, method) {
657 Ok((_, fn_ty)) => fn_ty,
658 Err(_) => {
659 let id = self.symbols.get_id(method)?;
660 if self.symbols.get_symbol(id)?.1.is_fn() {
661 Type::Symbol { id, params: Vec::new() }
662 } else {
663 return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
664 }
665 }
666 };
667 if let Type::Symbol { id, .. } = fn_ty {
668 let mut args = vec![ty];
669 for p in params {
670 args.push(self.infer_expr(p)?);
671 }
672 self.infer_fn(id, &args)
673 } else {
674 Ok(fn_ty)
675 }
676 } else {
677 Ok(Type::Any)
678 }
679 } else if let ExprKind::Var(idx) = &obj.kind {
680 let idx = self.top() + (*idx as usize);
681 if idx < self.tys.len()
682 && let Type::Symbol { id, .. } = self.tys[idx]
683 {
684 let mut args = Vec::new();
685 for p in params {
686 args.push(self.infer_expr(p)?);
687 }
688 self.infer_fn(id, &args)
689 } else {
690 Ok(Type::Any)
691 }
692 } else if obj.is_value() {
693 Ok(Type::Void)
694 } else {
695 Ok(Type::Any)
696 }
697 }
698 ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
699 ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
700 ExprKind::Repeat { value, len } => {
701 let value_ty = self.infer_expr(value)?;
702 let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
703 if let Type::ConstInt(len) = len {
704 let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
705 Ok(Type::Array(std::rc::Rc::new(value_ty), len))
706 } else {
707 Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
708 }
709 }
710 ExprKind::List(items) => {
711 if items.is_empty() {
712 return Ok(Type::list_any());
713 }
714 let mut elem_ty = Type::Any;
715 for item in items {
716 let item_ty = self.infer_expr(item)?;
717 elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
718 }
719 Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
720 }
721 ExprKind::Range { start, stop, .. } => {
722 let start_ty = self.infer_expr(start)?;
723 let stop_ty = self.infer_expr(stop)?;
724 Ok(Self::merge_range_bound_types(start_ty, stop_ty))
725 }
726 _ => Ok(Type::Any),
727 }
728 }
729
730 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
731 let mut fn_tys = Vec::new();
732 for (i, ty) in tys.iter().enumerate() {
733 if !ty.is_any() {
734 fn_tys.push(ty.clone());
735 } else if let Some(arg_ty) = arg_tys.get(i) {
736 fn_tys.push(self.symbols.get_type(arg_ty)?);
737 } else {
738 fn_tys.push(Type::Any);
739 }
740 }
741 Ok(fn_tys)
742 }
743
744 fn is_optimizable_local_ty(ty: &Type) -> bool {
745 ty.is_bool() || ty.is_native()
746 }
747
748 fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
749 matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
750 }
751
752 fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
753 let ty = self.tys.get(pos)?;
754 match ty {
755 Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
756 if let ListElemState::Known(elem_ty) = state
757 && Self::is_optimizable_list_elem_ty(&elem_ty)
758 {
759 Some(Type::List(std::rc::Rc::new(elem_ty)))
760 } else {
761 None
762 }
763 }),
764 ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
765 _ => None,
766 }
767 }
768
769 fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
770 (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
771 }
772
773 fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
774 let items = self.local_type_hints.entry(id).or_default();
775 if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
776 item.2 = hints;
777 } else {
778 items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
779 }
780 }
781
782 pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
783 self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
784 }
785
786 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
787 self.infer_fn_with_params(id, arg_tys, &[])
788 }
789
790 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
791 if self.infer_stack.len() > MAX_INFER_DEPTH {
794 return Ok(Type::Any);
795 }
796 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
797 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
798 if let Type::Fn { tys, ret: _ } = ty {
799 let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
800 let generic_args = resolved_generic_args.as_slice();
801 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
802 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
803 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
804 let body = if generic_params.is_empty() {
805 body
806 } else {
807 let mut compile_tys = tys.clone();
808 let mut compile_cap = cap.clone();
809 let saved_state = self.take_local_state();
810 if let Some((module, _)) = name.split_once("::") {
811 self.symbols.push_module_scope(module.into());
812 }
813 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
814 if name.contains("::") {
815 self.symbols.pop_module_scope();
816 }
817 self.restore_local_state(saved_state);
818 Stmt::new(StmtKind::Block(compiled?), Span::default())
819 };
820 if let Some(fns) = self.fns.get_mut(&id) {
821 for f in fns.iter() {
822 if f.0 == generic_args && f.1 == fn_tys {
823 return match &f.2 {
824 FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
825 FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or_else(|| {
826 if self.infer_stack.iter().any(|(sid, sargs, _)| *sid == id && sargs == generic_args) {
828 if let Some(base_ty) = self.try_find_base_return_ty(&body) {
829 return self.symbols.get_type(&base_ty);
830 }
831 }
832 Ok(Type::Any)
833 }),
834 };
835 }
836 }
837 fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
838 } else {
839 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
840 }
841 if self.pending_return_seed(id, generic_args, &fn_tys).is_none() {
843 if let Some(base_ty) = self.try_find_base_return_ty_with_scope(&body, id, &name, &args, &fn_tys) {
844 if let Some(fns) = self.fns.get_mut(&id) {
845 if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
846 && let FnInferRet::Pending(seed) = &mut item.2
847 && seed.is_none()
848 {
849 *seed = Some(base_ty);
850 }
851 }
852 }
853 }
854 let mut ret_ty = None;
855 let mut local_type_hints = Vec::new();
856 for _ in 0..4 {
857 let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
858 let saved_state = self.take_local_state();
859 self.frames.push(0);
860 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
861 self.add_name(arg.clone());
862 self.add_ty(ty.clone());
863 }
864 for c in cap.vars.iter() {
865 if let Some((name, ty)) = cap.names.get(*c) {
866 self.add_name(name.clone());
867 self.add_ty(ty.clone());
868 } else {
869 self.add_name("".into());
870 self.add_ty(Type::Any);
871 }
872 }
873 self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
874 let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
875 self.infer_stack.pop();
876 let pass_local_type_hints = self.collect_local_type_hints();
877 self.restore_local_state(saved_state);
878 let pass_ret_ty = match pass_ret_ty {
879 Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
880 Err(err) => {
881 log::error!("infer_fn {} failed: {:?}", name, err);
882 let should_remove = self
883 .fns
884 .get_mut(&id)
885 .map(|fns| {
886 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
887 fns.is_empty()
888 })
889 .unwrap_or(false);
890 if should_remove {
891 self.fns.remove(&id);
892 }
893 return Err(err);
894 }
895 };
896 if !pass_ret_ty.is_any() {
897 self.update_pending_return_seed(&pass_ret_ty);
898 ret_ty = Some(pass_ret_ty.clone());
899 } else if ret_ty.is_none() {
900 ret_ty = Some(pass_ret_ty);
901 }
902 local_type_hints = pass_local_type_hints;
903 let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
904 if before_seed == after_seed {
905 break;
906 }
907 }
908 let ret_ty = ret_ty.unwrap_or(Type::Any);
909 self.fns.get_mut(&id).map(|f| {
910 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
911 });
912 self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
913 if generic_args.is_empty()
914 && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
915 && ret.is_any()
916 {
917 *ret = std::rc::Rc::new(ret_ty.clone());
918 }
919 Ok(ret_ty)
920 } else {
921 Ok(Type::Any)
922 }
923 } else if let Symbol::Native(f) = s {
924 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
925 } else if matches!(s, Symbol::Null) {
926 Ok(Type::Any)
927 } else {
928 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
929 }
930 }
931
932 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
933 match &stmt.kind {
934 StmtKind::Expr(expr, close) => {
935 if !close {
936 self.infer_expr(expr)
937 } else {
938 self.infer_expr(expr)?;
939 Ok(Type::Void)
940 }
941 }
942 StmtKind::Return(expr) => {
943 if let Some(e) = expr {
944 self.infer_expr(e)
945 } else {
946 Ok(Type::Void)
947 }
948 }
949 StmtKind::Block(stmts) => {
950 for (idx, stmt) in stmts.iter().enumerate() {
951 let ty = self.infer_stmt(stmt)?;
952 if stmt.is_return() || idx == stmts.len() - 1 {
953 return Ok(ty);
954 }
955 }
956 Ok(Type::Void)
957 }
958 StmtKind::If { then_body, else_body, .. } => {
959 let then_ty = self.infer_stmt(then_body)?;
960 if let Some(e) = else_body {
961 let else_ty = self.infer_stmt(e)?;
962 if then_ty != else_ty {
963 log::debug!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
964 return Self::merge_return_type(stmt.span, Some(then_ty), else_ty);
965 }
966 }
967 if else_body.is_none() {
968 return Ok(Type::Void);
969 }
970 Ok(then_ty)
971 }
972 StmtKind::While { cond, body } => {
973 let cond_ty = self.infer_expr(cond)?;
974 if cond_ty != Type::Bool {
975 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
976 }
977 self.infer_stmt(body)
978 }
979 StmtKind::For { pat, range, body } => {
980 let ty = self.for_pattern_ty(range)?;
981 self.add_pattern_bindings_for_infer(pat, ty)?;
982 self.infer_stmt(body)
983 }
984 StmtKind::Let { pat, value } => {
985 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
986 self.add_pattern_bindings_for_infer(pat, expr_ty)?;
987 Ok(Type::Void)
988 }
989 _ => Ok(Type::Void),
990 }
991 }
992}