1use super::{Compiler, FnInferRet, ListElemState, Symbol};
2use anyhow::Result;
3use dynamic::{Dynamic, Type};
4use parser::{BinaryOp, Expr, ExprKind, Pattern, PatternKind, Span, Stmt, StmtKind, UnaryOp};
5
6#[derive(Clone)]
7struct ReturnInfo {
8 ty: Type,
9 shape: Option<Type>,
10}
11
12impl Compiler {
13 fn current_infer_key(&self) -> Option<(u32, Vec<Type>, Vec<Type>)> {
14 self.infer_stack.last().cloned()
15 }
16
17 fn pending_return_seed(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Option<Type> {
18 self.fns.get(&id).and_then(|fns| {
19 fns.iter().find_map(|item| {
20 if item.0 == generic_args
21 && item.1 == fn_tys
22 && let FnInferRet::Pending(seed) = &item.2
23 {
24 seed.clone()
25 } else {
26 None
27 }
28 })
29 })
30 }
31
32 fn update_pending_return_seed(&mut self, ty: &Type) {
33 if ty.is_any() {
34 return;
35 }
36 let Some((id, generic_args, fn_tys)) = self.current_infer_key() else {
37 return;
38 };
39 let Some(fns) = self.fns.get_mut(&id) else {
40 return;
41 };
42 if let Some(item) = fns.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys)
43 && let FnInferRet::Pending(seed) = &mut item.2
44 {
45 let next = seed.take().map(|prev| prev + ty.clone()).unwrap_or_else(|| ty.clone());
46 *seed = Some(next);
47 }
48 }
49
50 fn add_pattern_bindings_for_infer(&mut self, pat: &Pattern, expr_ty: Type) -> Result<()> {
51 match &pat.kind {
52 PatternKind::Ident { name, ty } => {
53 let annotated_ty = self.symbols.get_type(ty)?;
54 self.add_name(name.clone());
55 self.add_ty(if annotated_ty.is_any() { expr_ty } else { annotated_ty });
56 }
57 PatternKind::Var { idx, .. } => self.set_ty(*idx, expr_ty),
58 PatternKind::Tuple(pats) => {
59 if let Type::Tuple(tys) = expr_ty {
60 for (pat, ty) in pats.iter().zip(tys) {
61 self.add_pattern_bindings_for_infer(pat, ty)?;
62 }
63 } else {
64 for pat in pats {
65 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
66 }
67 }
68 }
69 PatternKind::List { elems, .. } => {
70 for pat in elems {
71 self.add_pattern_bindings_for_infer(pat, Type::Any)?;
72 }
73 }
74 PatternKind::Wildcard => {
75 self.add_name("".into());
76 self.add_ty(expr_ty);
77 }
78 PatternKind::Literal(_) | PatternKind::Member(_, _) | PatternKind::Idx(_, _) => {}
79 }
80 Ok(())
81 }
82
83 fn for_pattern_ty(&mut self, range: &Expr) -> Result<Type> {
84 if matches!(range.kind, ExprKind::Range { .. }) {
85 return self.infer_range_expr(range);
86 }
87 Ok(match self.infer_expr(range)? {
88 Type::Array(elem_ty, _) | Type::Vec(elem_ty, _) | Type::List(elem_ty) => elem_ty.as_ref().clone(),
89 _ => Type::Any,
90 })
91 }
92
93 fn infer_range_expr(&mut self, range: &Expr) -> Result<Type> {
94 let ExprKind::Range { start, stop, .. } = &range.kind else {
95 return self.infer_expr(range);
96 };
97 let start_ty = self.infer_expr(start)?;
98 let stop_ty = self.infer_expr(stop)?;
99 Ok(Self::merge_range_bound_types(start_ty, stop_ty))
100 }
101
102 fn merge_range_bound_types(start_ty: Type, stop_ty: Type) -> Type {
103 if start_ty.is_any() {
104 stop_ty
105 } else if stop_ty.is_any() {
106 start_ty
107 } else if start_ty == Type::I32 && stop_ty.is_uint() {
108 stop_ty
109 } else if stop_ty == Type::I32 && start_ty.is_uint() {
110 start_ty
111 } else {
112 start_ty + stop_ty
113 }
114 }
115
116 fn merge_return_type(span: Span, left: Option<Type>, right: Type) -> Result<Type> {
117 match left {
118 Some(left) if left == right => Ok(left),
119 Some(left) if left.is_void() || right.is_void() => Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left, right))),
120 Some(left) => Ok(left + right),
121 None => Ok(right),
122 }
123 }
124
125 fn return_shape(&self, expr: &Expr, ty: &Type) -> Option<Type> {
126 if !ty.is_any() {
127 return match ty {
128 Type::Struct { .. } => Some(ty.clone()),
129 Type::Map => Some(Type::Map),
130 Type::List(elem) | Type::Array(elem, _) => Some(Type::List(elem.clone())),
131 _ => None,
132 };
133 }
134 match &expr.kind {
135 ExprKind::List(_) | ExprKind::Tuple(_) => Some(Type::list_any()),
136 ExprKind::Dict(_) => Some(Type::Map),
137 ExprKind::Value(value) => Self::dynamic_return_shape(value.get_type()),
138 ExprKind::Const(idx) => self.consts.get(*idx).and_then(|value| Self::dynamic_return_shape(value.get_type())),
139 ExprKind::Typed { ty, .. } => Some(ty.clone()),
140 _ => None,
141 }
142 }
143
144 fn dynamic_return_shape(ty: Type) -> Option<Type> {
145 match ty {
146 Type::Map => Some(Type::Map),
147 Type::List(elem) => Some(Type::List(elem)),
148 Type::Array(elem, _) => Some(Type::List(elem)),
149 _ => None,
150 }
151 }
152
153 fn local_var_idx_for_expr(&self, expr: &Expr) -> Option<u32> {
154 match &expr.kind {
155 ExprKind::Var(idx) => Some(*idx),
156 ExprKind::Ident(name) => (self.top()..self.names.len()).rev().find(|idx| self.names[*idx].eq(name)).map(|idx| (idx - self.top()) as u32),
157 _ => None,
158 }
159 }
160
161 fn infer_list_method(&mut self, target: &Expr, elem_ty: &Type, method: &str, params: &[Expr]) -> Result<Option<Type>> {
162 match method {
163 "get_idx" | "pop" => Ok(Some(match self.local_var_idx_for_expr(target).and_then(|idx| self.list_elem_state(idx)) {
164 Some(ListElemState::Known(ty)) => ty,
165 Some(ListElemState::Unknown | ListElemState::Mixed) => Type::Any,
166 None => elem_ty.clone(),
167 })),
168 "push" => {
169 let pushed_ty = params
170 .first()
171 .map(|param| {
172 if let Some(value) = self.get_value(param)
173 && (value.is_str() || value.is_native())
174 {
175 Ok(value.get_type())
176 } else {
177 self.infer_expr(param)
178 }
179 })
180 .transpose()?
181 .unwrap_or(Type::Any);
182 if let Some(idx) = self.local_var_idx_for_expr(target) {
183 let state = self.list_elem_state(idx).unwrap_or_else(|| if elem_ty.is_any() { ListElemState::Unknown } else { ListElemState::Known(elem_ty.clone()) });
184 let next_state = match state {
185 ListElemState::Unknown if pushed_ty.is_any() => ListElemState::Mixed,
186 ListElemState::Unknown => ListElemState::Known(pushed_ty),
187 ListElemState::Known(_) if pushed_ty.is_any() => ListElemState::Mixed,
188 ListElemState::Known(prev) => {
189 let merged = if prev == pushed_ty {
190 prev
191 } else if (prev.is_int() || prev.is_uint() || prev.is_float()) && (pushed_ty.is_int() || pushed_ty.is_uint() || pushed_ty.is_float()) {
192 prev + pushed_ty
193 } else {
194 Type::Any
195 };
196 if merged.is_any() { ListElemState::Mixed } else { ListElemState::Known(merged) }
197 }
198 ListElemState::Mixed => ListElemState::Mixed,
199 };
200 let next_elem = if let ListElemState::Known(ty) = &next_state { ty.clone() } else { Type::Any };
201 self.set_ty(idx, Type::List(std::rc::Rc::new(next_elem)));
202 self.set_list_elem_state(idx, Some(next_state));
203 }
204 Ok(Some(Type::Void))
205 }
206 "len" => Ok(Some(Type::I32)),
207 "is_list" | "is_null" => Ok(Some(Type::Bool)),
208 _ => Ok(None),
209 }
210 }
211
212 fn infer_return_expr(&mut self, expr: &Expr) -> Result<ReturnInfo> {
213 let ty = self.infer_expr(expr)?;
214 let shape = self.return_shape(expr, &ty);
215 let ty = if matches!(shape, Some(Type::Map | Type::List(_))) { Type::Any } else { ty };
216 Ok(ReturnInfo { ty, shape })
217 }
218
219 fn merge_return_info(span: Span, left: Option<ReturnInfo>, right: ReturnInfo) -> Result<ReturnInfo> {
220 let Some(left) = left else {
221 return Ok(right);
222 };
223 if let (Some(left_shape), Some(right_shape)) = (&left.shape, &right.shape)
224 && left_shape != right_shape
225 {
226 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, right_shape)));
227 }
228 if let Some(left_shape) = &left.shape
229 && left_shape.is_struct()
230 && right.ty.is_any()
231 && right.shape.is_none()
232 {
233 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", left_shape, Type::Any)));
234 }
235 if let Some(right_shape) = &right.shape
236 && right_shape.is_struct()
237 && left.ty.is_any()
238 && left.shape.is_none()
239 {
240 return Err(Self::semantic_error(span, format!("返回类型不一致: {:?} 和 {:?}", Type::Any, right_shape)));
241 }
242 let ty = Self::merge_return_type(span, Some(left.ty), right.ty)?;
243 Ok(ReturnInfo { ty, shape: left.shape.or(right.shape) })
244 }
245
246 fn infer_return_type(&mut self, stmt: &Stmt) -> Result<Option<Type>> {
247 self.infer_returns(stmt, true).map(|(info, _)| info.map(|info| info.ty))
248 }
249
250 pub(crate) fn check_return_type(&mut self, stmt: &Stmt) -> Result<()> {
251 self.infer_returns(stmt, true).map(|_| ())
252 }
253
254 fn infer_returns(&mut self, stmt: &Stmt, tail: bool) -> Result<(Option<ReturnInfo>, bool)> {
255 match &stmt.kind {
256 StmtKind::Return(Some(expr)) => Ok((Some(self.infer_return_expr(expr)?), true)),
257 StmtKind::Return(None) => Ok((Some(ReturnInfo { ty: Type::Void, shape: Some(Type::Void) }), true)),
258 StmtKind::Block(stmts) => {
259 let mut ret = None;
260 for (idx, stmt) in stmts.iter().enumerate() {
261 let (info, always_returns) = self.infer_returns(stmt, tail && idx == stmts.len().saturating_sub(1))?;
262 if let Some(info) = info {
263 self.update_pending_return_seed(&info.ty);
264 ret = Some(Self::merge_return_info(stmt.span, ret, info)?);
265 if let Some(ret) = &ret {
266 self.update_pending_return_seed(&ret.ty);
267 }
268 }
269 if always_returns {
270 return Ok((ret, true));
271 }
272 }
273 Ok((ret, false))
274 }
275 StmtKind::If { cond, then_body, else_body } => {
276 let cond_ty = self.infer_expr(cond)?;
277 if cond_ty != Type::Bool {
278 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
279 }
280 let (mut ret, then_returns) = self.infer_returns(then_body, tail)?;
281 if let Some(ret) = &ret {
282 self.update_pending_return_seed(&ret.ty);
283 }
284 let else_returns = if let Some(body) = else_body {
285 let (else_ty, else_returns) = self.infer_returns(body, tail)?;
286 if let Some(info) = else_ty {
287 self.update_pending_return_seed(&info.ty);
288 ret = Some(Self::merge_return_info(body.span, ret, info)?);
289 if let Some(ret) = &ret {
290 self.update_pending_return_seed(&ret.ty);
291 }
292 }
293 else_returns
294 } else {
295 false
296 };
297 Ok((ret, then_returns && else_returns))
298 }
299 StmtKind::While { cond, body } => {
300 let cond_ty = self.infer_expr(cond)?;
301 if cond_ty != Type::Bool {
302 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
303 }
304 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
305 }
306 StmtKind::Loop(body) => self.infer_returns(body, false),
307 StmtKind::For { pat, range, body } => {
308 let ty = self.for_pattern_ty(range)?;
309 self.add_pattern_bindings_for_infer(pat, ty)?;
310 self.infer_returns(body, false).map(|(ty, _)| (ty, false))
311 }
312 StmtKind::Let { .. } => {
313 self.infer_stmt(stmt)?;
314 Ok((None, false))
315 }
316 StmtKind::Expr(expr, close) => {
317 let info = self.infer_return_expr(expr)?;
318 Ok(if *close || !tail { (None, false) } else { (Some(info), true) })
319 }
320 _ => {
321 self.infer_stmt(stmt)?;
322 Ok((None, false))
323 }
324 }
325 }
326
327 pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
328 match &expr.kind {
329 ExprKind::Value(Dynamic::Null) => Ok(Type::Any),
330 ExprKind::Value(v) if v.is_list() => Ok(v.get_type()),
331 ExprKind::Value(v) if v.is_map() => Ok(Type::Any),
332 ExprKind::Value(v) => Ok(v.get_type()),
333 ExprKind::Const(idx) => Ok(if self.consts.get(*idx).is_some_and(|value| value.is_list() && value.len() == 0) { Type::list_any() } else { Type::Any }),
334 ExprKind::Var(idx) => {
335 let idx = self.top() + (*idx as usize);
336 if idx < self.tys.len() { self.symbols.get_type(&self.tys[idx]) } else { Ok(Type::Any) }
337 }
338 ExprKind::Ident(ident) => {
339 for idx in (self.top()..self.names.len()).rev() {
340 if self.names[idx].eq(ident) && idx < self.tys.len() {
341 return self.symbols.get_type(&self.tys[idx]);
342 }
343 }
344 let id = self.symbols.get_id(ident).map_err(|_| Self::semantic_error(expr.span, format!("未找到标识符 {}", ident)))?;
345 match self.symbols.get_symbol(id)?.1 {
346 Symbol::Const { ty, .. } => Ok(ty.clone()),
347 Symbol::Static { ty, .. } => Ok(ty.clone()),
348 Symbol::Struct(ty, _) => Ok(ty.clone()),
349 Symbol::Fn { .. } => Ok(Type::Symbol { id, params: Vec::new() }),
350 Symbol::Native(ty) => Ok(ty.clone()),
351 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
352 }
353 }
354 ExprKind::Id(id, _) => match self.symbols.get_symbol(*id)?.1 {
355 Symbol::Const { ty, .. } => Ok(ty.clone()),
356 Symbol::Static { ty, .. } => Ok(ty.clone()),
357 Symbol::Struct(ty, _) => Ok(ty.clone()),
358 Symbol::Fn { .. } => Ok(Type::Symbol { id: *id, params: Vec::new() }),
359 Symbol::Native(ty) => Ok(ty.clone()),
360 s => Err(Self::semantic_error(expr.span, format!("符号 {:?} 不是变量、常量、静态变量、结构体", s))),
361 },
362 ExprKind::Generic { obj, params } => {
363 let params = params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect();
364 match self.infer_expr(obj)? {
365 Type::Symbol { id, .. } => Ok(Type::Symbol { id, params }),
366 _ => Ok(Type::Any),
367 }
368 }
369 ExprKind::AssocId { id, params } => Ok(Type::Symbol { id: *id, params: params.clone() }),
370 ExprKind::Unary { op, value } => match op {
371 UnaryOp::Not => {
372 let ty = self.infer_expr(value.as_ref())?;
373 if ty.is_int() || ty.is_uint() { Ok(ty) } else { Ok(Type::Bool) }
374 }
375 UnaryOp::Neg => self.infer_expr(value.as_ref()),
376 UnaryOp::Unknow => Ok(Type::Any),
377 },
378 ExprKind::Binary { left, op, right } => {
379 if op == &BinaryOp::Assign
380 && let ExprKind::Tuple(left_items) | ExprKind::List(left_items) = &left.kind
381 {
382 if let ExprKind::Tuple(right_items) | ExprKind::List(right_items) = &right.kind {
383 if left_items.len() != right_items.len() {
384 return Err(Self::semantic_error(expr.span, format!("多重赋值数量不匹配: 左侧 {} 个,右侧 {} 个", left_items.len(), right_items.len())));
385 }
386 for item in right_items {
387 let _ = self.infer_expr(item)?;
388 }
389 } else {
390 let _ = self.infer_expr(right)?;
391 }
392 return Ok(Type::Void);
393 }
394 let assign_idx = if op.is_assign() { if let ExprKind::Var(idx) = &left.kind { Some(*idx) } else { None } } else { None };
395 let ty = if op.is_logic() {
396 Type::Bool
397 } else if op == &BinaryOp::Idx {
398 let left_ty = self.infer_expr(left)?;
399 if let Type::Array(elem_ty, _) = left_ty {
400 (*elem_ty).clone()
401 } else if let Type::Vec(elem_ty, _) = left_ty {
402 (*elem_ty).clone()
403 } else if let Type::List(elem_ty) = left_ty {
404 (*elem_ty).clone()
405 } else {
406 let left_ty = self.symbols.get_type(&left_ty)?;
407 let right_ty = if right.is_value() || right.is_const() {
408 let right_value = if let ExprKind::Const(c) = &right.kind { self.consts[*c].clone() } else { right.clone().value()? };
409 if right_value.is_str() {
410 if left_ty.is_any() {
411 return Ok(Type::Any);
412 }
413 if let Ok(field) = self.symbols.get_field(&left_ty, right_value.as_str()) {
414 return if let Type::Fn { ret, .. } = field.1 { Ok(ret.as_ref().clone()) } else { Ok(field.1.clone()) };
415 }
416 } else if let Type::Struct { fields, .. } = &left_ty
417 && let Some(idx) = right_value.as_int()
418 {
419 return fields.get(idx as usize).map(|(_, ty)| ty.clone()).ok_or_else(|| Self::semantic_error(right.span, format!("结构字段索引越界 {}", idx)));
420 }
421 right_value.get_type()
422 } else {
423 self.infer_expr(right)?
424 };
425 if right_ty.is_int() || right_ty.is_uint() {
426 if left_ty.is_any() {
427 return Ok(Type::Any);
428 }
429 let (_, s) = self.symbols.get_field(&left_ty, "get_idx")?;
430 let fn_ty = self.symbols.get_type(&s)?;
431 return if let Type::Fn { ret, .. } = &fn_ty { Ok(ret.as_ref().clone()) } else { Ok(fn_ty) };
432 }
433 if left_ty.is_any() {
434 return Ok(Type::Any);
435 }
436 Type::Any
437 }
438 } else {
439 let left_ty = self.infer_expr(left)?;
440 let right_ty = self.infer_expr(right)?;
441 if op == &BinaryOp::Assign {
442 if !left_ty.is_any() && right_ty.is_any() { left_ty } else { right_ty }
443 } else if op.is_assign() && !left_ty.is_any() && right_ty.is_any() {
444 left_ty
445 } else {
446 left_ty + right_ty
447 }
448 };
449 assign_idx.map(|idx| self.set_ty(idx, ty.clone()));
450 Ok(ty)
451 }
452 ExprKind::Call { obj, params } => {
453 if let ExprKind::Assoc { ty, name } = &obj.kind {
454 let base_name = match ty {
455 Type::Ident { name, .. } => name.clone(),
456 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
457 _ => return Ok(Type::Any),
458 };
459 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
460 let generic_args = match ty {
461 Type::Ident { params, .. } | Type::Symbol { params, .. } => params.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>(),
462 _ => Vec::new(),
463 };
464 let mut args = Vec::new();
465 for p in params {
466 args.push(self.infer_expr(p)?);
467 }
468 self.infer_fn_with_params(id, &args, &generic_args)
469 } else if let ExprKind::AssocId { id, params: generic_args } = &obj.kind {
470 let mut args = Vec::new();
471 for p in params {
472 args.push(self.infer_expr(p)?);
473 }
474 self.infer_fn_with_params(*id, &args, generic_args)
475 } else if let ExprKind::Generic { obj, params: generic_args } = &obj.kind {
476 let Type::Symbol { id, .. } = self.infer_expr(obj)? else {
477 return Ok(Type::Any);
478 };
479 let generic_args = generic_args.iter().map(|param| self.symbols.get_type(param).unwrap_or_else(|_| param.clone())).collect::<Vec<_>>();
480 let mut args = Vec::new();
481 for p in params {
482 args.push(self.infer_expr(p)?);
483 }
484 self.infer_fn_with_params(id, &args, &generic_args)
485 } else if let ExprKind::TypedMethod { obj: target, ty, name } = &obj.kind {
486 let base_name = match ty {
487 Type::Ident { name, .. } => name.clone(),
488 Type::Symbol { id, .. } => self.symbols.get_symbol(*id)?.0.clone(),
489 _ => return Ok(Type::Any),
490 };
491 let id = self.symbols.get_id(&format!("{}::{}", base_name, name))?;
492 let mut args = vec![self.infer_expr(target)?];
493 for p in params {
494 args.push(self.infer_expr(p)?);
495 }
496 self.infer_fn(id, &args)
497 } else if let ExprKind::Id(id, obj_expr) = &obj.kind {
498 let method = self.symbols.get_symbol(*id).ok().and_then(|(name, _)| name.rsplit_once("::").map(|(_, method)| method.to_string()));
499 if let Some(target) = obj_expr
500 && let Some(method) = method
501 {
502 let target_ty = self.infer_expr(target)?;
503 if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &target_ty
504 && let Some(ret_ty) = self.infer_list_method(target, elem_ty, method.as_str(), params)?
505 {
506 return Ok(ret_ty);
507 }
508 }
509 let mut args: Vec<Type> = if let Some(obj) = obj_expr { vec![self.infer_expr(obj)?] } else { Vec::new() };
510 for p in params {
511 args.push(self.infer_expr(p)?);
512 }
513 self.infer_fn(*id, &args)
514 } else if let ExprKind::Ident(name) = &obj.kind {
515 for idx in (self.top()..self.names.len()).rev() {
516 if self.names[idx].eq(name) && idx < self.tys.len() {
517 return if let Type::Symbol { id, .. } = &self.tys[idx] {
518 let id = *id;
519 let mut args = Vec::new();
520 for p in params {
521 args.push(self.infer_expr(p)?);
522 }
523 self.infer_fn(id, &args)
524 } else {
525 Ok(Type::Any)
526 };
527 }
528 }
529 let Ok(id) = self.symbols.get_id(name) else {
530 return Ok(Type::Any);
531 };
532 if !self.symbols.get_symbol(id)?.1.is_fn() {
533 return Err(Self::semantic_error(obj.span, format!("符号 {} 不是函数", name)));
534 }
535 let mut args = Vec::new();
536 for p in params {
537 args.push(self.infer_expr(p)?);
538 }
539 self.infer_fn(id, &args)
540 } else if obj.is_idx() {
541 let (target, _, method) = obj.clone().binary().unwrap();
542 let ty = self.infer_expr(&target)?;
543 if let Some(method) = self.get_value(&method) {
544 let method = method.as_str();
545 if let Type::List(elem_ty) | Type::Array(elem_ty, _) = &ty
546 && let Some(ret_ty) = self.infer_list_method(&target, elem_ty, method, params)?
547 {
548 return Ok(ret_ty);
549 }
550 let fn_ty = match self.get_field(&ty, method) {
551 Ok((_, fn_ty)) => fn_ty,
552 Err(_) => {
553 let id = self.symbols.get_id(method)?;
554 if self.symbols.get_symbol(id)?.1.is_fn() {
555 Type::Symbol { id, params: Vec::new() }
556 } else {
557 return Err(Self::semantic_error(obj.span, format!("符号 {method} 不是函数")));
558 }
559 }
560 };
561 if let Type::Symbol { id, .. } = fn_ty {
562 let mut args = vec![ty];
563 for p in params {
564 args.push(self.infer_expr(p)?);
565 }
566 self.infer_fn(id, &args)
567 } else {
568 Ok(fn_ty)
569 }
570 } else {
571 Ok(Type::Any)
572 }
573 } else if let ExprKind::Var(idx) = &obj.kind {
574 let idx = self.top() + (*idx as usize);
575 if idx < self.tys.len()
576 && let Type::Symbol { id, .. } = self.tys[idx]
577 {
578 let mut args = Vec::new();
579 for p in params {
580 args.push(self.infer_expr(p)?);
581 }
582 self.infer_fn(id, &args)
583 } else {
584 Ok(Type::Any)
585 }
586 } else if obj.is_value() {
587 Ok(Type::Void)
588 } else {
589 Ok(Type::Any)
590 }
591 }
592 ExprKind::Typed { ty, .. } => self.symbols.get_type(ty),
593 ExprKind::Stmt(stmt) => self.infer_stmt(stmt),
594 ExprKind::Repeat { value, len } => {
595 let value_ty = self.infer_expr(value)?;
596 let len = self.symbols.get_type(len).unwrap_or_else(|_| len.clone());
597 if let Type::ConstInt(len) = len {
598 let len = u32::try_from(len).map_err(|_| Self::semantic_error(expr.span, "重复数组长度必须是非负 u32"))?;
599 Ok(Type::Array(std::rc::Rc::new(value_ty), len))
600 } else {
601 Ok(Type::ArrayParam(std::rc::Rc::new(value_ty), std::rc::Rc::new(len)))
602 }
603 }
604 ExprKind::List(items) => {
605 if items.is_empty() {
606 return Ok(Type::list_any());
607 }
608 let mut elem_ty = Type::Any;
609 for item in items {
610 let item_ty = self.infer_expr(item)?;
611 elem_ty = if elem_ty.is_any() { item_ty } else { elem_ty + item_ty };
612 }
613 Ok(Type::Array(std::rc::Rc::new(elem_ty), items.len() as u32))
614 }
615 ExprKind::Range { start, stop, .. } => {
616 let start_ty = self.infer_expr(start)?;
617 let stop_ty = self.infer_expr(stop)?;
618 Ok(Self::merge_range_bound_types(start_ty, stop_ty))
619 }
620 _ => Ok(Type::Any),
621 }
622 }
623
624 fn get_fn_tys(&mut self, tys: &[Type], arg_tys: &[Type]) -> Result<Vec<Type>> {
625 let mut fn_tys = Vec::new();
626 for (i, ty) in tys.iter().enumerate() {
627 if !ty.is_any() {
628 fn_tys.push(ty.clone());
629 } else if let Some(arg_ty) = arg_tys.get(i) {
630 fn_tys.push(self.symbols.get_type(arg_ty)?);
631 } else {
632 fn_tys.push(Type::Any);
633 }
634 }
635 Ok(fn_tys)
636 }
637
638 fn is_optimizable_local_ty(ty: &Type) -> bool {
639 ty.is_bool() || ty.is_native()
640 }
641
642 fn is_optimizable_list_elem_ty(ty: &Type) -> bool {
643 matches!(ty, Type::Bool | Type::U8 | Type::I8 | Type::U16 | Type::I16 | Type::U32 | Type::I32 | Type::F32 | Type::U64 | Type::I64 | Type::F64 | Type::Str)
644 }
645
646 fn local_type_hint_at(&self, pos: usize) -> Option<Type> {
647 let ty = self.tys.get(pos)?;
648 match ty {
649 Type::List(_) => self.list_elem_states.get(pos).cloned().flatten().and_then(|state| {
650 if let ListElemState::Known(elem_ty) = state
651 && Self::is_optimizable_list_elem_ty(&elem_ty)
652 {
653 Some(Type::List(std::rc::Rc::new(elem_ty)))
654 } else {
655 None
656 }
657 }),
658 ty if Self::is_optimizable_local_ty(ty) => Some(ty.clone()),
659 _ => None,
660 }
661 }
662
663 fn collect_local_type_hints(&self) -> Vec<Option<Type>> {
664 (self.top()..self.tys.len()).map(|pos| self.local_type_hint_at(pos)).collect()
665 }
666
667 fn set_local_type_hints(&mut self, id: u32, generic_args: &[Type], fn_tys: &[Type], hints: Vec<Option<Type>>) {
668 let items = self.local_type_hints.entry(id).or_default();
669 if let Some(item) = items.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys) {
670 item.2 = hints;
671 } else {
672 items.push((generic_args.to_vec(), fn_tys.to_vec(), hints));
673 }
674 }
675
676 pub fn inferred_local_type_hints(&self, id: u32, generic_args: &[Type], fn_tys: &[Type]) -> Vec<Option<Type>> {
677 self.local_type_hints.get(&id).and_then(|items| items.iter().find(|item| item.0 == generic_args && item.1 == fn_tys)).map(|item| item.2.clone()).unwrap_or_default()
678 }
679
680 pub fn infer_fn(&mut self, id: u32, arg_tys: &[Type]) -> Result<Type> {
681 self.infer_fn_with_params(id, arg_tys, &[])
682 }
683
684 pub fn infer_fn_with_params(&mut self, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<Type> {
685 let (name, s) = self.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
686 if let Symbol::Fn { ty, args, generic_params, cap, body, .. } = s {
687 if let Type::Fn { tys, ret: _ } = ty {
688 let resolved_generic_args = crate::resolve_generic_args_from_types(&generic_params, &tys, arg_tys, generic_args)?;
689 let generic_args = resolved_generic_args.as_slice();
690 let tys = if generic_params.is_empty() { tys } else { tys.iter().map(|ty| crate::substitute_type(ty, &generic_params, generic_args)).collect() };
691 let body = if generic_params.is_empty() { body.as_ref().clone() } else { crate::substitute_stmt(body.as_ref(), &generic_params, generic_args) };
692 let fn_tys = self.get_fn_tys(&tys, arg_tys)?;
693 let body = if generic_params.is_empty() {
694 body
695 } else {
696 let mut compile_tys = tys.clone();
697 let mut compile_cap = cap.clone();
698 let saved_state = self.take_local_state();
699 if let Some((module, _)) = name.split_once("::") {
700 self.symbols.push_module_scope(module.into());
701 }
702 let compiled = self.compile_fn(&args, &mut compile_tys, body, &mut compile_cap);
703 if name.contains("::") {
704 self.symbols.pop_module_scope();
705 }
706 self.restore_local_state(saved_state);
707 Stmt::new(StmtKind::Block(compiled?), Span::default())
708 };
709 if let Some(fns) = self.fns.get_mut(&id) {
710 for f in fns.iter() {
711 if f.0 == generic_args && f.1 == fn_tys {
712 return match &f.2 {
713 FnInferRet::Done(ret_ty) => self.symbols.get_type(ret_ty),
714 FnInferRet::Pending(seed) => seed.as_ref().map(|ty| self.symbols.get_type(ty)).unwrap_or(Ok(Type::Any)),
715 };
716 }
717 }
718 fns.push((generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None)));
719 } else {
720 self.fns.insert(id, vec![(generic_args.to_vec(), fn_tys.clone(), FnInferRet::Pending(None))]);
721 }
722 let mut ret_ty = None;
723 let mut local_type_hints = Vec::new();
724 for _ in 0..4 {
725 let before_seed = self.pending_return_seed(id, generic_args, &fn_tys);
726 let saved_state = self.take_local_state();
727 self.frames.push(0);
728 for (arg, ty) in args.iter().zip(fn_tys.iter()) {
729 self.add_name(arg.clone());
730 self.add_ty(ty.clone());
731 }
732 for c in cap.vars.iter() {
733 if let Some((name, ty)) = cap.names.get(*c) {
734 self.add_name(name.clone());
735 self.add_ty(ty.clone());
736 } else {
737 self.add_name("".into());
738 self.add_ty(Type::Any);
739 }
740 }
741 self.infer_stack.push((id, generic_args.to_vec(), fn_tys.clone()));
742 let pass_ret_ty = self.infer_return_type(&body).map(|ty| ty.unwrap_or(Type::Void));
743 self.infer_stack.pop();
744 let pass_local_type_hints = self.collect_local_type_hints();
745 self.restore_local_state(saved_state);
746 let pass_ret_ty = match pass_ret_ty {
747 Ok(pass_ret_ty) => self.symbols.get_type(&pass_ret_ty).unwrap_or(pass_ret_ty),
748 Err(err) => {
749 log::error!("infer_fn {} failed: {:?}", name, err);
750 let should_remove = self
751 .fns
752 .get_mut(&id)
753 .map(|fns| {
754 fns.retain(|item| item.0 != generic_args || item.1 != fn_tys || !matches!(item.2, FnInferRet::Pending(_)));
755 fns.is_empty()
756 })
757 .unwrap_or(false);
758 if should_remove {
759 self.fns.remove(&id);
760 }
761 return Err(err);
762 }
763 };
764 if !pass_ret_ty.is_any() {
765 self.update_pending_return_seed(&pass_ret_ty);
766 ret_ty = Some(pass_ret_ty.clone());
767 } else if ret_ty.is_none() {
768 ret_ty = Some(pass_ret_ty);
769 }
770 local_type_hints = pass_local_type_hints;
771 let after_seed = self.pending_return_seed(id, generic_args, &fn_tys);
772 if before_seed == after_seed {
773 break;
774 }
775 }
776 let ret_ty = ret_ty.unwrap_or(Type::Any);
777 self.fns.get_mut(&id).map(|f| {
778 f.iter_mut().find(|item| item.0 == generic_args && item.1 == fn_tys).map(|item| item.2 = FnInferRet::Done(ret_ty.clone()));
779 });
780 self.set_local_type_hints(id, generic_args, &fn_tys, local_type_hints);
781 if generic_args.is_empty()
782 && let Some((_, Symbol::Fn { ty: Type::Fn { ret, .. }, .. })) = self.symbols.get_symbol_mut(id)
783 && ret.is_any()
784 {
785 *ret = std::rc::Rc::new(ret_ty.clone());
786 }
787 Ok(ret_ty)
788 } else {
789 Ok(Type::Any)
790 }
791 } else if let Symbol::Native(f) = s {
792 if let Type::Fn { ret, .. } = f { Ok((*ret).clone()) } else { Ok(Type::Any) }
793 } else if matches!(s, Symbol::Null) {
794 Ok(Type::Any)
795 } else {
796 Err(Self::semantic_error(Span::default(), format!("符号 {:?} 不是函数", name)))
797 }
798 }
799
800 pub fn infer_stmt(&mut self, stmt: &Stmt) -> Result<Type> {
801 match &stmt.kind {
802 StmtKind::Expr(expr, close) => {
803 if !close {
804 self.infer_expr(expr)
805 } else {
806 self.infer_expr(expr)?;
807 Ok(Type::Void)
808 }
809 }
810 StmtKind::Return(expr) => {
811 if let Some(e) = expr {
812 self.infer_expr(e)
813 } else {
814 Ok(Type::Void)
815 }
816 }
817 StmtKind::Block(stmts) => {
818 for (idx, stmt) in stmts.iter().enumerate() {
819 let ty = self.infer_stmt(stmt)?;
820 if stmt.is_return() || idx == stmts.len() - 1 {
821 return Ok(ty);
822 }
823 }
824 Ok(Type::Void)
825 }
826 StmtKind::If { then_body, else_body, .. } => {
827 let then_ty = self.infer_stmt(then_body)?;
828 if let Some(e) = else_body {
829 let else_ty = self.infer_stmt(e)?;
830 if then_ty != else_ty {
831 log::info!("then 和 else 有不同类型 {:?} {:?}", then_ty, else_ty);
832 return Ok(if then_ty.is_any() { else_ty } else { then_ty });
833 }
834 }
835 if else_body.is_none() {
836 return Ok(Type::Void);
837 }
838 Ok(then_ty)
839 }
840 StmtKind::While { cond, body } => {
841 let cond_ty = self.infer_expr(cond)?;
842 if cond_ty != Type::Bool {
843 return Err(Self::semantic_error(cond.span, format!("条件表达式必须是布尔类型,实际是 {:?}", cond_ty)));
844 }
845 self.infer_stmt(body)
846 }
847 StmtKind::For { pat, range, body } => {
848 let ty = self.for_pattern_ty(range)?;
849 self.add_pattern_bindings_for_infer(pat, ty)?;
850 self.infer_stmt(body)
851 }
852 StmtKind::Let { pat, value } => {
853 let expr_ty = if let StmtKind::Expr(expr, _) = &value.kind { self.infer_expr(expr)? } else { self.infer_stmt(value)? };
854 self.add_pattern_bindings_for_infer(pat, expr_ty)?;
855 Ok(Type::Void)
856 }
857 _ => Ok(Type::Void),
858 }
859 }
860}