Skip to main content

vm/
fns.rs

1use crate::{JITRunTime, context::BuildContext, rt::PendingFn};
2use anyhow::{Context, Result, anyhow};
3use compiler::{Symbol, resolve_generic_args_from_types, substitute_stmt, substitute_type};
4use cranelift::{codegen::ir::FuncRef, prelude::*};
5use cranelift_module::{FuncId, Module};
6use dynamic::Type;
7
/// One concrete compiled instantiation of a script function.
///
/// A symbol can accumulate several of these inside `FnVariant::Compiled`,
/// distinguished by their generic arguments and their concrete signature.
#[derive(Debug)]
pub struct CompiledVariant {
    /// Generic arguments this variant was instantiated with (empty for a non-generic variant).
    generic_args: Vec<Type>,
    /// Concrete `Type::Fn` signature of this variant. Its parameter list may include
    /// capture types appended after the declared parameters.
    ty: Type,
    /// Cranelift id under which this variant is declared in the module.
    fn_id: FuncId,
}
14
/// How calls to a symbol are implemented by the JIT.
#[derive(Debug)]
pub enum FnVariant {
    /// Native function without variants; it is called directly.
    /// `context` is forwarded unchanged into `FnInfo::Call` (its meaning is not visible here — TODO confirm).
    Native { ty: Type, fn_id: FuncId, context: Option<usize> },
    /// Inline function: `fn_ptr` generates the code directly at each call site.
    Inline { fn_ptr: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>, arg_tys: Vec<Type> },
    /// Script function compiled by the JIT; one entry per compiled instantiation.
    Compiled(Vec<CompiledVariant>),
}
21
22impl FnVariant {
23    pub fn is_compiled(&self) -> bool {
24        if let Self::Compiled(_) = self { true } else { false }
25    }
26}
27
28use crate::get_type;
29use cranelift_module::Linkage;
30use parser::{Expr, ExprKind, Span, Stmt, StmtKind};
31use smol_str::SmolStr;
32
/// Information needed to emit a call to a function.
#[derive(Debug)]
pub enum FnInfo {
    /// A real call to the Cranelift function `fn_id`.
    /// `arg_tys` are the parameter types to convert arguments to (including any appended
    /// capture types), `caps` are indices of captured variables in the caller's
    /// `BuildContext::vars`, and `ret` is the return type. `context` is copied from
    /// `FnVariant::Native` (meaning not visible here — TODO confirm).
    Call { fn_id: FuncId, arg_tys: Vec<Type>, caps: Vec<usize>, ret: Type, context: Option<usize> },
    /// An inline function whose code is generated by `fn_ptr` at the call site.
    Inline { fn_ptr: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>, arg_tys: Vec<Type> },
}
39
40impl FnInfo {
41    pub fn get_id(&self) -> Result<FuncId> {
42        if let Self::Call { fn_id, arg_tys: _, caps: _, ret: _, context: _ } = self { Ok(*fn_id) } else { Err(anyhow!("Inline 函数没有 FuncId")) }
43    }
44
45    pub fn arg_tys(&self) -> Result<&[Type]> {
46        match self {
47            Self::Call { fn_id: _, arg_tys, caps: _, ret: _, context: _ } => Ok(arg_tys),
48            Self::Inline { fn_ptr: _, arg_tys } => Ok(arg_tys),
49        }
50    }
51
52    pub fn get_type(&self) -> Result<Type> {
53        match self {
54            Self::Call { fn_id: _, arg_tys: _, caps: _, ret, context: _ } => Ok(ret.clone()),
55            Self::Inline { fn_ptr, arg_tys: _ } => fn_ptr(None, vec![]).map(|(_, t)| t),
56        }
57    }
58}
59
60impl JITRunTime {
61    fn coerce_returns(stmt: &Stmt, ret_ty: &Type) -> Stmt {
62        let kind = match &stmt.kind {
63            StmtKind::Return(Some(expr)) if ret_ty.is_void() => StmtKind::Return(None),
64            StmtKind::Return(Some(expr)) => StmtKind::Return(Some(Expr::new(ExprKind::Typed { value: Box::new(expr.clone()), ty: ret_ty.clone() }, expr.span))),
65            StmtKind::Block(stmts) => StmtKind::Block(stmts.iter().map(|stmt| Self::coerce_returns(stmt, ret_ty)).collect()),
66            StmtKind::If { cond, then_body, else_body } => {
67                StmtKind::If { cond: cond.clone(), then_body: Box::new(Self::coerce_returns(then_body, ret_ty)), else_body: else_body.as_ref().map(|body| Box::new(Self::coerce_returns(body, ret_ty))) }
68            }
69            StmtKind::While { cond, body } => StmtKind::While { cond: cond.clone(), body: Box::new(Self::coerce_returns(body, ret_ty)) },
70            StmtKind::Loop(body) => StmtKind::Loop(Box::new(Self::coerce_returns(body, ret_ty))),
71            StmtKind::For { pat, range, body } => StmtKind::For { pat: pat.clone(), range: range.clone(), body: Box::new(Self::coerce_returns(body, ret_ty)) },
72            _ => stmt.kind.clone(),
73        };
74        Stmt::new(kind, stmt.span)
75    }
76
77    pub fn add_inline(&mut self, name: &str, args: Vec<Type>, ret: Type, f: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>) -> Result<u32> {
78        let id = self.compiler.add_symbol(name, Symbol::native(args.clone(), ret));
79        self.fns.insert(id, FnVariant::Inline { fn_ptr: f.into(), arg_tys: args });
80        if let Some((def, method)) = name.split_once("::") {
81            let def_id = self.get_id(def)?;
82            if let Some((_, define)) = self.compiler.symbols.get_symbol_mut(def_id) {
83                if let Symbol::Struct(Type::Struct { params, fields }, _) = define {
84                    fields.push((method.into(), Type::Symbol { id, params: params.clone() }));
85                }
86            }
87        }
88        Ok(id)
89    }
90
91    pub fn get_fn_ref(&mut self, ctx: &mut BuildContext, fn_id: FuncId) -> FuncRef {
92        ctx.get_fn_ref(fn_id).unwrap_or_else(|| {
93            let fn_ref = self.module.declare_func_in_func(fn_id, &mut ctx.builder.func);
94            ctx.fn_refs.push((fn_id, fn_ref));
95            fn_ref
96        })
97    }
98
99    pub fn adjust_args(&mut self, ctx: &mut BuildContext, args: Vec<(Value, Type)>, arg_tys: &[Type]) -> Result<Vec<Value>> {
100        let mut results = Vec::<Value>::new();
101        for ((arg, ty), arg_ty) in args.into_iter().zip(arg_tys.iter()) {
102            if ty != *arg_ty {
103                results.push(self.convert(ctx, (arg, ty), arg_ty.clone())?);
104            } else {
105                results.push(arg);
106            }
107        }
108        Ok(results)
109    }
110
111    pub fn get_fn(&self, id: u32, want_tys: &[Type]) -> Result<FnInfo> {
112        if let Some(fn_info) = self.fns.get(&id) {
113            match fn_info {
114                FnVariant::Compiled(fns) => {
115                    for variant in fns.iter() {
116                        if !variant.generic_args.is_empty() {
117                            continue;
118                        }
119                        if let Type::Fn { tys, ret } = variant.ty.clone() {
120                            if tys.len() != want_tys.len() {
121                                continue;
122                            }
123                            let mut real_types = Vec::new();
124                            for (ty1, ty2) in tys.iter().zip(want_tys.iter()) {
125                                if ty1 != ty2 {
126                                    if ty1.is_any() || ty2.is_any() {
127                                        real_types.push(ty1.clone());
128                                    }
129                                    //ty1 是目的类型
130                                    else {
131                                        break;
132                                    }
133                                } else {
134                                    real_types.push(ty1.clone());
135                                }
136                            }
137                            if real_types.len() == want_tys.len() {
138                                return Ok(FnInfo::Call { fn_id: variant.fn_id, arg_tys: real_types, caps: Vec::new(), ret: ret.as_ref().clone(), context: None });
139                            }
140                        }
141                    }
142                }
143                FnVariant::Inline { fn_ptr, arg_tys } => {
144                    return Ok(FnInfo::Inline { fn_ptr: fn_ptr.clone(), arg_tys: arg_tys.clone() });
145                }
146                FnVariant::Native { ty, fn_id, context } => {
147                    if let Type::Fn { tys, ret } = ty.clone() {
148                        return Ok(FnInfo::Call { fn_id: *fn_id, arg_tys: tys, caps: Vec::new(), ret: ret.as_ref().clone(), context: *context });
149                    }
150                }
151            }
152        }
153        Err(anyhow!("未发现函数 {}", id))
154    }
155
156    pub fn get_sig(&mut self, arg_tys: &[Type], ret: Type) -> Result<Signature> {
157        if let Some(st) = self.sigs.iter().find_map(|s| if s.0 == arg_tys && ret == s.2 { Some(s.1.clone()) } else { None }) {
158            return Ok(st);
159        }
160        let mut sig = self.module.make_signature();
161        for arg in arg_tys.iter() {
162            sig.params.push(AbiParam::new(get_type(arg)?));
163        }
164        if !ret.is_void() {
165            sig.returns.push(AbiParam::new(get_type(&ret)?));
166        }
167        self.sigs.push((arg_tys.to_vec(), sig.clone(), ret.clone()));
168        Ok(sig)
169    }
170
171    fn declare_compiled_fn(&mut self, name_id: Option<&(SmolStr, u32)>, generic_args: &[Type], arg_tys: &[Type], ret_ty: Type) -> Result<FuncId> {
172        let sig = self.get_sig(arg_tys, ret_ty.clone())?;
173        log::debug!("{:?} {:?}", name_id, sig);
174        if let Some((name, id)) = name_id {
175            let variant_idx = match self.fns.get(id) {
176                Some(FnVariant::Compiled(fns)) => fns.len(),
177                _ => 0,
178            };
179            let jit_name = if variant_idx == 0 && generic_args.is_empty() { name.to_string() } else { format!("{name}#{variant_idx}") };
180            let fn_id = self.module.declare_function(&jit_name, Linkage::Local, &sig)?;
181            let variant = CompiledVariant { generic_args: generic_args.to_vec(), ty: Type::Fn { tys: arg_tys.to_vec(), ret: std::rc::Rc::new(ret_ty.clone()) }, fn_id };
182            if let Some(FnVariant::Compiled(fns)) = self.fns.get_mut(id) {
183                fns.push(variant);
184            } else {
185                self.fns.insert(*id, FnVariant::Compiled(vec![variant]));
186            }
187            Ok(fn_id)
188        } else {
189            Ok(self.module.declare_anonymous_function(&sig)?)
190        }
191    }
192
    /// Builds the body `stmt` and defines it as the already-declared function `fn_id`.
    ///
    /// Every `return` in `stmt` is first coerced to `ret_ty` (via `coerce_returns`). The
    /// body is then lowered through `gen_stmt` with `compile_depth` incremented. As a
    /// result, functions first referenced while generating it are only declared and queued
    /// in `pending_fns` rather than compiled recursively. `name_id` is only used for
    /// diagnostics: the error context and, with the `ir-disassembly` feature, the key under
    /// which the IR text is stored.
    ///
    /// # Errors
    /// Propagates failures from signature creation, scope setup, code generation, and
    /// `Module::define_function` (the latter annotated with the function name).
    fn define_compiled_fn(&mut self, fn_id: FuncId, name_id: Option<&(SmolStr, u32)>, arg_tys: &[Type], ret_ty: Type, local_type_hints: Vec<Option<Type>>, stmt: &Stmt) -> Result<()> {
        let sig = self.get_sig(arg_tys, ret_ty.clone())?;
        #[cfg(feature = "ir-disassembly")]
        let fn_name = name_id.map(|(name, _)| name.clone());
        let mut ctx = self.module.make_context();
        ctx.func.signature = sig.clone();

        let mut func_ctx = FunctionBuilderContext::new();
        let builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx);

        let mut build_ctx = BuildContext::with_local_type_hints(builder, &arg_tys, ret_ty.clone(), local_type_hints)?;
        self.scope_enter(&mut build_ctx)?;
        // While the depth is non-zero, newly needed functions are queued instead of compiled.
        self.compile_depth += 1;
        let stmt = Self::coerce_returns(stmt, &ret_ty);
        let gen_result = self.gen_stmt(&mut build_ctx, &stmt, None, None);
        // Restore the depth before propagating any error so the counter stays balanced.
        self.compile_depth -= 1;
        gen_result?;

        // NOTE(review): `FunctionBuilder::finalize` is never called; since `func_ctx` is local
        // this is probably harmless unless stack maps are used — confirm.
        build_ctx.builder.seal_all_blocks();
        #[cfg(feature = "ir-disassembly")]
        {
            let ir = format!("{}", ctx.func.display());
            if let Some(name) = fn_name {
                self.ir_disassembly.insert(name, ir);
            }
        }
        self.module.define_function(fn_id, &mut ctx).with_context(|| name_id.map(|(name, _)| format!("define function {}", name)).unwrap_or_else(|| "define anonymous function".to_string()))?;
        log::debug!("{:?}", ctx.func);
        Ok(())
    }
223
224    pub(crate) fn compile_fn(&mut self, name_id: Option<(SmolStr, u32)>, arg_tys: &[Type], ret_ty: Type, stmt: &Stmt) -> Result<FuncId> {
225        self.compile_fn_with_generic_args(name_id, &[], arg_tys, ret_ty, stmt)
226    }
227
228    pub(crate) fn compile_fn_with_generic_args(&mut self, name_id: Option<(SmolStr, u32)>, generic_args: &[Type], arg_tys: &[Type], ret_ty: Type, stmt: &Stmt) -> Result<FuncId> {
229        self.compile_fn_with_generic_args_and_local_type_hints(name_id, generic_args, arg_tys, ret_ty, Vec::new(), stmt)
230    }
231
232    pub(crate) fn compile_fn_with_generic_args_and_local_type_hints(
233        &mut self,
234        name_id: Option<(SmolStr, u32)>,
235        generic_args: &[Type],
236        arg_tys: &[Type],
237        ret_ty: Type,
238        local_type_hints: Vec<Option<Type>>,
239        stmt: &Stmt,
240    ) -> Result<FuncId> {
241        let drain_pending = self.compile_depth == 0;
242        let fn_id = self.declare_compiled_fn(name_id.as_ref(), generic_args, arg_tys, ret_ty.clone())?;
243        self.define_compiled_fn(fn_id, name_id.as_ref(), arg_tys, ret_ty, local_type_hints, stmt)?;
244        if drain_pending {
245            self.drain_pending_fns()?;
246        }
247        Ok(fn_id)
248    }
249
250    fn drain_pending_fns(&mut self) -> Result<()> {
251        while let Some(pending) = self.pending_fns.pop_front() {
252            let name_id = (pending.name, pending.symbol_id);
253            self.define_compiled_fn(pending.fn_id, Some(&name_id), &pending.arg_tys, pending.ret_ty, pending.local_type_hints, &pending.body)?;
254        }
255        Ok(())
256    }
257
258    pub fn gen_fn(&mut self, ctx: Option<&BuildContext>, id: u32, arg_tys: &[Type]) -> Result<FnInfo> {
259        self.gen_fn_with_params(ctx, id, arg_tys, &[])
260    }
261
262    pub fn gen_fn_with_params(&mut self, ctx: Option<&BuildContext>, id: u32, arg_tys: &[Type], generic_args: &[Type]) -> Result<FnInfo> {
263        self.gen_fn_with_capture_tys(ctx, id, arg_tys, generic_args, None)
264    }
265
    /// Returns call information for symbol `id` invoked with `arg_tys`, compiling a new
    /// variant when no suitable one exists.
    ///
    /// * `ctx` — the caller's build context. When present and `capture_tys` is `None`,
    ///   capture types are read from `ctx.vars`.
    /// * `generic_args` — explicit generic arguments. For `Type::Fn` symbols, the missing
    ///   ones are resolved from the argument types via `resolve_generic_args_from_types`.
    /// * `capture_tys` — explicit types for captured variables. Their count must match the
    ///   capture list, and they are appended after the declared parameters.
    ///
    /// While another function is being compiled (`compile_depth > 0`), the new variant is
    /// only declared and queued in `pending_fns`. At top level it is compiled immediately,
    /// pending functions are drained, and module definitions are finalized.
    ///
    /// # Errors
    /// Fails when:
    /// * the symbol is not a function;
    /// * generic resolution or return-type inference fails;
    /// * the capture count mismatches;
    /// * compilation or finalization fails.
    ///
    /// # Panics
    /// Panics if an argument type cannot be resolved by `symbols.get_type` (it is unwrapped),
    /// or if a captured variable index is out of range of `ctx.vars`.
    pub(crate) fn gen_fn_with_capture_tys(&mut self, ctx: Option<&BuildContext>, id: u32, arg_tys: &[Type], generic_args: &[Type], capture_tys: Option<&[Type]>) -> Result<FnInfo> {
        // NOTE(review): `unwrap` panics on an unresolvable type even though this function returns `Result`.
        let mut arg_tys: Vec<Type> = arg_tys.iter().map(|ty| self.compiler.symbols.get_type(ty).unwrap()).collect();
        // Fast path: a plain call (no explicit captures or generics) reuses any compatible
        // native, inline, or already compiled implementation.
        if capture_tys.is_none()
            && generic_args.is_empty()
            && let Ok(info) = self.get_fn(id, &arg_tys)
        {
            return Ok(info);
        }
        let (name, s) = self.compiler.symbols.get_symbol(id).map(|(n, s)| (n.clone(), s.clone()))?;
        if let Symbol::Fn { ty, args, generic_params, cap, body, is_pub: _ } = s.clone() {
            if let Type::Fn { tys: decl_tys, ret: _ } = ty {
                // Complete the generic arguments from the actual argument types, then
                // specialise the declared parameter types with them.
                let resolved_generic_args = resolve_generic_args_from_types(&generic_params, &decl_tys, &arg_tys, generic_args)?;
                let generic_args = resolved_generic_args.as_slice();
                let decl_tys = if generic_params.is_empty() { decl_tys } else { decl_tys.iter().map(|ty| substitute_type(ty, &generic_params, generic_args)).collect() };
                // Pad omitted trailing arguments with their declared types (`Any` if unresolvable).
                while arg_tys.len() < decl_tys.len() {
                    arg_tys.push(self.compiler.symbols.get_type(&decl_tys[arg_tys.len()]).unwrap_or(Type::Any));
                }
                let ret_ty = self.compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
                let local_type_hints = self.compiler.inferred_local_type_hints(id, generic_args, &arg_tys);
                // Reuse a variant already compiled for exactly these generics, parameter and return types.
                // NOTE(review): `arg_tys` does not include capture types yet, while registered
                // variants do, so functions with captures never hit this cache and are
                // recompiled on every call — confirm intended.
                if let Some(FnVariant::Compiled(fns)) = self.fns.get(&id) {
                    for variant in fns {
                        if variant.generic_args.as_slice() != generic_args {
                            continue;
                        }
                        if let Type::Fn { tys, ret } = &variant.ty
                            && tys == &arg_tys
                            && ret.as_ref() == &ret_ty
                        {
                            return Ok(FnInfo::Call { fn_id: variant.fn_id, arg_tys: arg_tys.to_vec(), caps: Vec::new(), ret: ret_ty, context: None });
                        }
                    }
                }
                // `compile_cap` is passed mutably to the front-end below, so the capture list may change for generic bodies.
                let mut compile_cap = cap.clone();
                let body = if generic_params.is_empty() {
                    body.as_ref().clone()
                } else {
                    // Generic bodies are re-run through the front-end compiler with the generic
                    // parameters substituted.
                    let mut compile_tys = decl_tys.clone();
                    let substituted = substitute_stmt(body.as_ref(), &generic_params, generic_args);
                    // Isolate the front-end's local state and enter the symbol's module scope
                    // (for `module::name`); both are restored before `?` so errors do not leak them.
                    let saved_state = self.compiler.take_local_state();
                    if let Some((module, _)) = name.split_once("::") {
                        self.compiler.symbols.push_module_scope(module.into());
                    }
                    let compiled_body = self.compiler.compile_fn(&args, &mut compile_tys, substituted, &mut compile_cap);
                    if name.contains("::") {
                        self.compiler.symbols.pop_module_scope();
                    }
                    self.compiler.restore_local_state(saved_state);
                    Stmt::new(StmtKind::Block(compiled_body?), Span::default())
                };
                // Captured values are passed as extra trailing parameters.
                if let Some(capture_tys) = capture_tys {
                    if capture_tys.len() != compile_cap.vars.len() {
                        return Err(anyhow!("capture type count mismatch: got {}, want {}", capture_tys.len(), compile_cap.vars.len()));
                    }
                    arg_tys.extend_from_slice(capture_tys);
                } else {
                    // NOTE(review): with `ctx == None` capture types are silently skipped while
                    // `caps` still lists the captures — confirm this path is unreachable with captures.
                    for v in compile_cap.vars.iter() {
                        ctx.as_ref().map(|ctx| arg_tys.push(ctx.vars[*v].get_ty()));
                    }
                }
                // Nested inside another function's compilation: declare now, define later via
                // `drain_pending_fns`. At top level: compile, drain and finalize immediately.
                let fn_id = if self.compile_depth > 0 {
                    let fn_id = self.declare_compiled_fn(Some(&(name.clone(), id)), generic_args, &arg_tys, ret_ty.clone())?;
                    self.pending_fns.push_back(PendingFn { name: name.clone(), symbol_id: id, fn_id, arg_tys: arg_tys.clone(), ret_ty: ret_ty.clone(), local_type_hints, body });
                    fn_id
                } else {
                    let fn_id = self.compile_fn_with_generic_args_and_local_type_hints(Some((name.clone(), id)), generic_args, &arg_tys, ret_ty.clone(), local_type_hints, &body)?;
                    // The call above already drained at depth 0, so this looks redundant — kept for safety.
                    self.drain_pending_fns()?;
                    self.module.finalize_definitions()?;
                    fn_id
                };
                return Ok(FnInfo::Call { fn_id, arg_tys: arg_tys.to_vec(), caps: compile_cap.vars.clone(), ret: ret_ty, context: None });
            }
            // Declared type is not a `Type::Fn`: no generic resolution, no argument padding,
            // and no lookup of previously compiled variants; `generic_args` are used as given.
            let ret_ty = self.compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
            let local_type_hints = self.compiler.inferred_local_type_hints(id, generic_args, &arg_tys);
            if let Some(capture_tys) = capture_tys {
                if capture_tys.len() != cap.vars.len() {
                    return Err(anyhow!("capture type count mismatch: got {}, want {}", capture_tys.len(), cap.vars.len()));
                }
                arg_tys.extend_from_slice(capture_tys);
            } else {
                for v in cap.vars.iter() {
                    ctx.as_ref().map(|ctx| arg_tys.push(ctx.vars[*v].get_ty()));
                }
            }
            let body = if generic_params.is_empty() { body.as_ref().clone() } else { substitute_stmt(body.as_ref(), &generic_params, generic_args) };
            let fn_id = if self.compile_depth > 0 {
                let fn_id = self.declare_compiled_fn(Some(&(name.clone(), id)), generic_args, &arg_tys, ret_ty.clone())?;
                self.pending_fns.push_back(PendingFn { name: name.clone(), symbol_id: id, fn_id, arg_tys: arg_tys.clone(), ret_ty: ret_ty.clone(), local_type_hints, body });
                fn_id
            } else {
                let fn_id = self.compile_fn_with_generic_args_and_local_type_hints(Some((name.clone(), id)), generic_args, &arg_tys, ret_ty.clone(), local_type_hints, &body)?;
                self.drain_pending_fns()?;
                self.module.finalize_definitions()?;
                fn_id
            };
            return Ok(FnInfo::Call { fn_id, arg_tys: arg_tys.to_vec(), caps: cap.vars.clone(), ret: ret_ty, context: None });
        }
        Err(anyhow!("生成函数 {}({}) 失败: symbol 不是函数: {:?}", id, name, s))
    }
364}