Skip to main content

vm/
lib.rs

1//使用 cranelift 作为后端 直接 jit 解释脚本
2mod binary;
3mod native;
4pub use native::{ANY, STD};
5
6mod fns;
7use anyhow::{Result, anyhow};
8pub use fns::{FnInfo, FnVariant};
9mod context;
10pub use context::BuildContext;
11
12mod rt;
13use cranelift::prelude::types;
14use dynamic::Type;
15pub use rt::JITRunTime;
16use smol_str::SmolStr;
17mod http_module;
18mod llm_module;
19mod root_module;
20
21use std::cell::RefCell;
22use std::sync::{Mutex, OnceLock, Weak};
23static PTR_TYPE: OnceLock<types::Type> = OnceLock::new();
24pub fn ptr_type() -> types::Type {
25    PTR_TYPE.get().cloned().unwrap()
26}
27
28pub fn get_type(ty: &Type) -> Result<types::Type> {
29    if ty.is_f64() {
30        Ok(types::F64)
31    } else if ty.is_f32() {
32        Ok(types::F32)
33    } else if ty.is_int() | ty.is_uint() {
34        match ty.width() {
35            1 => Ok(types::I8),
36            2 => Ok(types::I16),
37            4 => Ok(types::I32),
38            8 => Ok(types::I64),
39            _ => Err(anyhow!("非法类型 {:?}", ty)),
40        }
41    } else if let Type::Bool = ty {
42        Ok(types::I8)
43    } else {
44        Ok(ptr_type())
45    }
46}
47
48use compiler::Symbol;
49use cranelift::prelude::*;
50
51pub fn init_jit(mut jit: JITRunTime) -> Result<JITRunTime> {
52    jit.add_all()?;
53    Ok(jit)
54}
55
56use std::sync::Arc;
57unsafe impl Send for JITRunTime {}
58unsafe impl Sync for JITRunTime {}
59
60thread_local! {
61    static CURRENT_VM: RefCell<Option<Weak<Mutex<JITRunTime>>>> = const { RefCell::new(None) };
62}
63
64fn set_current_vm(vm: &Vm) {
65    CURRENT_VM.with(|current| {
66        *current.borrow_mut() = Some(Arc::downgrade(&vm.jit));
67    });
68}
69
70fn with_current_vm<T>(f: impl FnOnce(&Vm) -> Result<T>) -> Result<T> {
71    CURRENT_VM.with(|current| {
72        let jit = current.borrow().as_ref().and_then(Weak::upgrade).ok_or_else(|| anyhow!("当前线程没有 VM"))?;
73        let vm = Vm { jit };
74        f(&vm)
75    })
76}
77
78pub(crate) fn import_current(name: &str, path: &str) -> Result<()> {
79    with_current_vm(|vm| vm.import(name, path))
80}
81
82pub(crate) fn get_current_fn_ptr(name: &str, arg_tys: &[Type]) -> Result<(*const u8, Type)> {
83    with_current_vm(|vm| vm.get_fn_ptr(name, arg_tys))
84}
85
86fn add_method_field(jit: &mut JITRunTime, def: &str, method: &str, id: u32) -> Result<()> {
87    let def_id = jit.get_id(def)?;
88    if let Some((_, define)) = jit.compiler.symbols.get_symbol_mut(def_id) {
89        if let Symbol::Struct(Type::Struct { params, fields }, _) = define {
90            fields.push((method.into(), Type::Symbol { id, params: params.clone() }));
91        }
92    }
93    Ok(())
94}
95
96fn add_native_module_fns(jit: &mut JITRunTime, module: &str, fns: &[(&str, &[Type], Type, *const u8)]) -> Result<()> {
97    jit.add_module(module);
98    for (name, arg_tys, ret_ty, fn_ptr) in fns {
99        let full_name = format!("{}::{}", module, name);
100        jit.add_native_ptr(&full_name, name, arg_tys, ret_ty.clone(), *fn_ptr)?;
101    }
102    jit.pop_module();
103    Ok(())
104}
105
106impl JITRunTime {
107    pub fn add_module(&mut self, name: &str) {
108        self.compiler.symbols.add_module(name.into());
109    }
110
111    pub fn pop_module(&mut self) {
112        self.compiler.symbols.pop_module();
113    }
114
115    pub fn add_type(&mut self, name: &str, ty: Type, is_pub: bool) -> u32 {
116        self.compiler.add_symbol(name, Symbol::Struct(ty, is_pub))
117    }
118
119    pub fn add_empty_type(&mut self, name: &str) -> Result<u32> {
120        match self.get_id(name) {
121            Ok(id) => Ok(id),
122            Err(_) => Ok(self.add_type(name, Type::Struct { params: Vec::new(), fields: Vec::new() }, true)),
123        }
124    }
125
126    pub fn add_native_module_ptr(&mut self, module: &str, name: &str, arg_tys: &[Type], ret_ty: Type, fn_ptr: *const u8) -> Result<u32> {
127        self.add_module(module);
128        let full_name = format!("{}::{}", module, name);
129        let result = self.add_native_ptr(&full_name, name, arg_tys, ret_ty, fn_ptr);
130        self.pop_module();
131        result
132    }
133
134    pub fn add_native_method_ptr(&mut self, def: &str, method: &str, arg_tys: &[Type], ret_ty: Type, fn_ptr: *const u8) -> Result<u32> {
135        self.add_empty_type(def)?;
136        let full_name = format!("{}::{}", def, method);
137        let id = self.add_native_ptr(&full_name, &full_name, arg_tys, ret_ty, fn_ptr)?;
138        add_method_field(self, def, method, id)?;
139        Ok(id)
140    }
141
142    pub fn add_std(&mut self) -> Result<()> {
143        self.add_module("std");
144        for (name, arg_tys, ret_ty, fn_ptr) in STD {
145            self.add_native_ptr(name, name, arg_tys, ret_ty, fn_ptr)?;
146        }
147        Ok(())
148    }
149
150    pub fn add_any(&mut self) -> Result<()> {
151        for (name, arg_tys, ret_ty, fn_ptr) in ANY {
152            let (_, method) = name.split_once("::").ok_or_else(|| anyhow!("非法 Any 方法名 {}", name))?;
153            self.add_native_method_ptr("Any", method, arg_tys, ret_ty, fn_ptr)?;
154        }
155        Ok(())
156    }
157
158    pub fn add_vec(&mut self) -> Result<()> {
159        self.add_empty_type("Vec")?;
160        let vec_def = Type::Symbol { id: self.get_id("Vec")?, params: Vec::new() };
161        self.add_inline("Vec::swap", vec![vec_def.clone(), Type::I64, Type::I64], Type::Void, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
162            if let Some(ctx) = ctx {
163                let width = ctx.builder.ins().iconst(types::I64, 4);
164                let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
165                let final_addr = ctx.builder.ins().iadd(args[0], offset_val); // base + (i*4)
166                let dest = ctx.builder.ins().imul(args[2], width);
167                let dest_addr = ctx.builder.ins().iadd(args[0], dest); // base + (i*4)
168                let dest_val = ctx.builder.ins().load(types::I32, MemFlags::trusted(), dest_addr, 0);
169                let v = ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0);
170                ctx.builder.ins().store(MemFlags::trusted(), v, dest_addr, 0);
171                ctx.builder.ins().store(MemFlags::trusted(), dest_val, final_addr, 0);
172            }
173            Err(anyhow!("无返回值"))
174        })?;
175
176        self.add_inline("Vec::get_idx", vec![vec_def.clone(), Type::I64], Type::I32, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
177            if let Some(ctx) = ctx {
178                let width = ctx.builder.ins().iconst(types::I64, 4);
179                let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
180                let final_addr = ctx.builder.ins().iadd(args[0], offset_val);
181                Ok((Some(ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0)), Type::I32))
182            } else {
183                Ok((None, Type::I32))
184            }
185        })?;
186        Ok(())
187    }
188
189    pub fn add_llm(&mut self) -> Result<()> {
190        add_native_module_fns(self, "llm", &llm_module::LLM_NATIVE)
191    }
192
193    pub fn add_root(&mut self) -> Result<()> {
194        add_native_module_fns(self, "root", &root_module::ROOT_NATIVE)
195    }
196
197    pub fn add_http(&mut self) -> Result<()> {
198        add_native_module_fns(self, "http", &http_module::HTTP_NATIVE)
199    }
200
201    pub fn add_all(&mut self) -> Result<()> {
202        self.add_std()?;
203        self.add_any()?;
204        self.add_vec()?;
205        self.add_llm()?;
206        self.add_root()?;
207        self.add_http()?;
208        Ok(())
209    }
210}
211
212#[derive(Clone)]
213pub struct Vm {
214    jit: Arc<Mutex<JITRunTime>>,
215}
216
217#[derive(Clone)]
218pub struct CompiledFn {
219    ptr: usize,
220    ret: Type,
221    owner: Vm,
222}
223
224impl CompiledFn {
225    pub fn ptr(&self) -> *const u8 {
226        set_current_vm(&self.owner);
227        self.ptr as *const u8
228    }
229
230    pub fn ret_ty(&self) -> &Type {
231        &self.ret
232    }
233
234    pub fn owner(&self) -> &Vm {
235        &self.owner
236    }
237}
238
239impl Vm {
240    pub fn new() -> Self {
241        Self { jit: Arc::new(Mutex::new(JITRunTime::new(|_| {}))) }
242    }
243
244    pub fn with_all() -> Result<Self> {
245        let vm = Self::new();
246        vm.add_all()?;
247        Ok(vm)
248    }
249
250    pub fn add_module(&self, name: &str) {
251        self.jit.lock().unwrap().add_module(name)
252    }
253
254    pub fn pop_module(&self) {
255        self.jit.lock().unwrap().pop_module()
256    }
257
258    pub fn add_type(&self, name: &str, ty: Type, is_pub: bool) -> u32 {
259        self.jit.lock().unwrap().add_type(name, ty, is_pub)
260    }
261
262    pub fn add_empty_type(&self, name: &str) -> Result<u32> {
263        self.jit.lock().unwrap().add_empty_type(name)
264    }
265
266    pub fn add_std(&self) -> Result<()> {
267        self.jit.lock().unwrap().add_std()
268    }
269
270    pub fn add_any(&self) -> Result<()> {
271        self.jit.lock().unwrap().add_any()
272    }
273
274    pub fn add_vec(&self) -> Result<()> {
275        self.jit.lock().unwrap().add_vec()
276    }
277
278    pub fn add_llm(&self) -> Result<()> {
279        self.jit.lock().unwrap().add_llm()
280    }
281
282    pub fn add_root(&self) -> Result<()> {
283        self.jit.lock().unwrap().add_root()
284    }
285
286    pub fn add_http(&self) -> Result<()> {
287        self.jit.lock().unwrap().add_http()
288    }
289
290    pub fn add_all(&self) -> Result<()> {
291        self.jit.lock().unwrap().add_all()
292    }
293
294    pub fn add_native_ptr(&self, full_name: &str, name: &str, arg_tys: &[Type], ret_ty: Type, fn_ptr: *const u8) -> Result<u32> {
295        self.jit.lock().unwrap().add_native_ptr(full_name, name, arg_tys, ret_ty, fn_ptr)
296    }
297
298    pub fn add_native_module_ptr(&self, module: &str, name: &str, arg_tys: &[Type], ret_ty: Type, fn_ptr: *const u8) -> Result<u32> {
299        self.jit.lock().unwrap().add_native_module_ptr(module, name, arg_tys, ret_ty, fn_ptr)
300    }
301
302    pub fn add_native_method_ptr(&self, def: &str, method: &str, arg_tys: &[Type], ret_ty: Type, fn_ptr: *const u8) -> Result<u32> {
303        self.jit.lock().unwrap().add_native_method_ptr(def, method, arg_tys, ret_ty, fn_ptr)
304    }
305
306    pub fn add_inline(&self, name: &str, args: Vec<Type>, ret: Type, f: fn(Option<&mut BuildContext>, Vec<Value>) -> Result<(Option<Value>, Type)>) -> Result<u32> {
307        self.jit.lock().unwrap().add_inline(name, args, ret, f)
308    }
309
310    pub fn import_code(&self, name: &str, code: Vec<u8>) -> Result<()> {
311        self.jit.lock().unwrap().import_code(name, code)
312    }
313
314    pub fn import_file(&self, name: &str, path: &str) -> Result<()> {
315        self.jit.lock().unwrap().compiler.import_file(name, path)?;
316        Ok(())
317    }
318
319    pub fn import(&self, name: &str, path: &str) -> Result<()> {
320        if root::contains(path) {
321            let code = root::get(path).unwrap();
322            if code.is_str() {
323                self.import_code(name, code.as_str().as_bytes().to_vec())
324            } else {
325                self.import_code(name, code.get_dynamic("code").ok_or(anyhow!("{:?} 没有 code 成员", code))?.as_str().as_bytes().to_vec())
326            }
327        } else {
328            self.import_file(name, path)
329        }
330    }
331
332    pub fn infer(&self, name: &str, arg_tys: &[Type]) -> Result<Type> {
333        self.jit.lock().unwrap().get_type(name, arg_tys)
334    }
335
336    pub fn get_fn_ptr(&self, name: &str, arg_tys: &[Type]) -> Result<(*const u8, Type)> {
337        self.jit.lock().unwrap().get_fn_ptr(name, arg_tys)
338    }
339
340    pub fn get_fn(&self, name: &str, arg_tys: &[Type]) -> Result<CompiledFn> {
341        set_current_vm(self);
342        let (ptr, ret) = self.get_fn_ptr(name, arg_tys)?;
343        Ok(CompiledFn { ptr: ptr as usize, ret, owner: self.clone() })
344    }
345
346    pub fn load(&self, code: Vec<u8>, arg_name: SmolStr) -> Result<(i64, Type)> {
347        self.jit.lock().unwrap().load(code, arg_name)
348    }
349
350    pub fn get_symbol(&self, name: &str, params: Vec<Type>) -> Result<Type> {
351        Ok(Type::Symbol { id: self.jit.lock().unwrap().get_id(name)?, params })
352    }
353
354    pub fn disassemble(&self, name: &str) -> Result<String> {
355        self.jit.lock().unwrap().compiler.symbols.disassemble(name)
356    }
357
358    #[cfg(feature = "ir-disassembly")]
359    pub fn disassemble_ir(&self, name: &str) -> Result<String> {
360        self.jit.lock().unwrap().disassemble_ir(name)
361    }
362}
363
364impl Default for Vm {
365    fn default() -> Self {
366        Self::new()
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::Vm;
    use dynamic::{Dynamic, ToJson, Type};

    extern "C" fn math_double(value: i64) -> i64 {
        value * 2
    }

    #[test]
    fn vm_can_add_native_after_jit_creation() -> anyhow::Result<()> {
        let vm = Vm::new();
        vm.add_native_module_ptr("math", "double", &[Type::I64], Type::I64, math_double as *const u8)?;
        vm.import_code(
            "vm_dynamic_native",
            br#"
            pub fn run(value: i64) {
                math::double(value)
            }
            "#
            .to_vec(),
        )?;

        let compiled = vm.get_fn("vm_dynamic_native::run", &[Type::I64])?;
        assert_eq!(compiled.ret_ty(), &Type::I64);
        // SAFETY: compiled for a single i64 argument; its return type was asserted above.
        let run: extern "C" fn(i64) -> i64 = unsafe { std::mem::transmute(compiled.ptr()) };
        assert_eq!(run(21), 42);
        Ok(())
    }

    #[test]
    fn compares_any_with_string_literal_as_string() -> anyhow::Result<()> {
        let vm = Vm::with_all()?;
        vm.import_code(
            "vm_string_compare_any",
            br#"
            pub fn any_ne_empty(chat_path) {
                chat_path != ""
            }
            "#
            .to_vec(),
        )?;

        let compiled = vm.get_fn("vm_string_compare_any::any_ne_empty", &[Type::Any])?;
        assert_eq!(compiled.ret_ty(), &Type::Bool);

        // SAFETY: compiled for a single Any (pointer) argument returning Bool.
        let any_ne_empty: extern "C" fn(*const Dynamic) -> bool = unsafe { std::mem::transmute(compiled.ptr()) };
        let empty = Dynamic::from("");
        let non_empty = Dynamic::from("chat");

        assert!(!any_ne_empty(&empty));
        assert!(any_ne_empty(&non_empty));
        Ok(())
    }

    #[test]
    fn compares_concrete_value_with_string_literal_as_string() -> anyhow::Result<()> {
        let vm = Vm::with_all()?;
        vm.import_code(
            "vm_string_compare_imm",
            br#"
            pub fn int_eq_str(value: i64) {
                value == "42"
            }

            pub fn int_to_str(value: i64) {
                value + ""
            }
            "#
            .to_vec(),
        )?;

        let compiled = vm.get_fn("vm_string_compare_imm::int_eq_str", &[Type::I64])?;
        assert_eq!(compiled.ret_ty(), &Type::Bool);

        // SAFETY: compiled for a single i64 argument returning Bool.
        let int_eq_str: extern "C" fn(i64) -> bool = unsafe { std::mem::transmute(compiled.ptr()) };

        let compiled = vm.get_fn("vm_string_compare_imm::int_to_str", &[Type::I64])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled for a single i64 argument returning Any (pointer).
        let int_to_str: extern "C" fn(i64) -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        let text = int_to_str(42);
        assert_eq!(unsafe { &*text }.as_str(), "42");

        assert!(int_eq_str(42));
        assert!(!int_eq_str(7));
        Ok(())
    }

    // Previously called non-existent free functions `import_code` / `get_fn`, so the test
    // module did not compile; it now goes through a `Vm` like the other tests.
    #[test]
    fn dynamic_field_value_participates_in_or_expression() -> anyhow::Result<()> {
        let vm = Vm::with_all()?;
        vm.import_code(
            "vm_dynamic_field_or",
            r#"
            pub fn next_or_start() {
                let choice = {
                    label: "颜色",
                    next: "color"
                };
                choice.next || "start"
            }

            pub fn direct_next() {
                let choice = {
                    label: "颜色",
                    next: "color"
                };
                choice.next
            }

            pub fn bracket_next() {
                let choice = {
                    label: "颜色",
                    next: "color"
                };
                choice["next"]
            }

            pub fn assigned_preview() {
                let choice = {
                    next: "tax_free"
                };
                choice.preview = choice.next || "start";
                choice
            }
            "#
            .as_bytes()
            .to_vec(),
        )?;

        let compiled = vm.get_fn("vm_dynamic_field_or::direct_next", &[])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled with no arguments, returning Any (pointer).
        let direct_next: extern "C" fn() -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        assert_eq!(unsafe { &*direct_next() }.as_str(), "color");

        let compiled = vm.get_fn("vm_dynamic_field_or::bracket_next", &[])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled with no arguments, returning Any (pointer).
        let bracket_next: extern "C" fn() -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        assert_eq!(unsafe { &*bracket_next() }.as_str(), "color");

        let compiled = vm.get_fn("vm_dynamic_field_or::next_or_start", &[])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled with no arguments, returning Any (pointer).
        let next_or_start: extern "C" fn() -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        assert_eq!(unsafe { &*next_or_start() }.as_str(), "color");

        let compiled = vm.get_fn("vm_dynamic_field_or::assigned_preview", &[])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled with no arguments, returning Any (pointer).
        let assigned_preview: extern "C" fn() -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        let choice = unsafe { &*assigned_preview() };
        assert_eq!(choice.get_dynamic("preview").unwrap().as_str(), "tax_free");
        Ok(())
    }

    #[test]
    fn root_native_calls_do_not_take_ownership_of_dynamic_args() -> anyhow::Result<()> {
        let vm = Vm::with_all()?;
        vm.import_code(
            "vm_root_clone_bridge",
            br#"
            pub fn add_then_reuse(arg) {
                let user = {
                    address: "test-wallet",
                    points: 20
                };
                root::add("local/root-clone-bridge-user", user);
                user.points = user.points - 7;
                root::add("local/root-clone-bridge-user", user);
                {
                    user: user,
                    points: user.points
                }
            }
            "#
            .to_vec(),
        )?;

        let compiled = vm.get_fn("vm_root_clone_bridge::add_then_reuse", &[Type::Any])?;
        assert_eq!(compiled.ret_ty(), &Type::Any);
        // SAFETY: compiled for a single Any (pointer) argument, returning Any (pointer).
        let add_then_reuse: extern "C" fn(*const Dynamic) -> *const Dynamic = unsafe { std::mem::transmute(compiled.ptr()) };
        let arg = Dynamic::Null;
        let result = add_then_reuse(&arg);
        let result = unsafe { &*result };

        assert_eq!(result.get_dynamic("points").and_then(|value| value.as_int()), Some(13));
        let mut json = String::new();
        result.to_json(&mut json);
        assert!(json.contains("\"points\": 13"));
        Ok(())
    }
}