Skip to main content

vm/
lib.rs

1//使用 cranelift 作为后端 直接 jit 解释脚本
2mod binary;
3mod native;
4pub use native::{ANY, STD};
5
6mod fns;
7use anyhow::{Result, anyhow};
8pub use fns::{FnInfo, FnVariant};
9mod context;
10use context::BuildContext;
11
12mod rt;
13use cranelift::prelude::types;
14use dynamic::Type;
15pub use rt::JITRunTime;
16use smol_str::SmolStr;
17mod llm_module;
18mod root_module;
19
20use std::sync::{OnceLock, RwLock};
21static PTR_TYPE: OnceLock<types::Type> = OnceLock::new();
22pub fn ptr_type() -> types::Type {
23    PTR_TYPE.get().cloned().unwrap()
24}
25
26pub fn get_type(ty: &Type) -> Result<types::Type> {
27    if ty.is_f64() {
28        Ok(types::F64)
29    } else if ty.is_f32() {
30        Ok(types::F32)
31    } else if ty.is_int() | ty.is_uint() {
32        match ty.width() {
33            1 => Ok(types::I8),
34            2 => Ok(types::I16),
35            4 => Ok(types::I32),
36            8 => Ok(types::I64),
37            _ => Err(anyhow!("非法类型 {:?}", ty)),
38        }
39    } else if let Type::Bool = ty {
40        Ok(types::I8)
41    } else {
42        Ok(ptr_type())
43    }
44}
45
46use compiler::Symbol;
47use cranelift::prelude::*;
48
49pub fn init_jit(mut jit: JITRunTime) -> Result<JITRunTime> {
50    jit.compiler.symbols.add_module("std".into()); //开始导入标准库,可以直接访问
51    for std in STD {
52        jit.add_native(std.0, std.0, std.1, std.2)?;
53    }
54
55    let mut fields = Vec::new();
56    for (name, arg_tys, ret_ty, _) in ANY {
57        let id = jit.add_native(name, name, arg_tys, ret_ty)?;
58        let (_, field_name) = name.split_once("::").unwrap();
59        fields.push((field_name.into(), Type::Symbol { id, params: Vec::new() }));
60    }
61    jit.compiler.add_symbol("Any", Symbol::Struct(Type::Struct { params: Vec::new(), fields }, true));
62
63    jit.compiler.add_symbol("Vec", Symbol::Struct(Type::Struct { params: Vec::new(), fields: Vec::new() }, true));
64    let vec_def = Type::Symbol { id: jit.get_id("Vec")?, params: Vec::new() };
65    jit.add_inline("Vec::swap", vec![vec_def.clone(), Type::I64, Type::I64], Type::Void, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
66        if let Some(ctx) = ctx {
67            let width = ctx.builder.ins().iconst(types::I64, 4);
68            let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
69            let final_addr = ctx.builder.ins().iadd(args[0], offset_val); // base + (i*4)
70            let dest = ctx.builder.ins().imul(args[2], width);
71            let dest_addr = ctx.builder.ins().iadd(args[0], dest); // base + (i*4)
72            let dest_val = ctx.builder.ins().load(types::I32, MemFlags::trusted(), dest_addr, 0);
73            let v = ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0);
74            ctx.builder.ins().store(MemFlags::trusted(), v, dest_addr, 0);
75            ctx.builder.ins().store(MemFlags::trusted(), dest_val, final_addr, 0);
76        }
77        Err(anyhow!("无返回值"))
78    })?;
79
80    jit.add_inline("Vec::get_idx", vec![vec_def.clone(), Type::I64], Type::I32, |ctx: Option<&mut BuildContext>, args: Vec<Value>| {
81        if let Some(ctx) = ctx {
82            let width = ctx.builder.ins().iconst(types::I64, 4);
83            let offset_val = ctx.builder.ins().imul(args[1], width); // i * 4 i32大小四字节
84            let final_addr = ctx.builder.ins().iadd(args[0], offset_val);
85            Ok((Some(ctx.builder.ins().load(types::I32, MemFlags::trusted(), final_addr, 0)), Type::I32))
86        } else {
87            Ok((None, Type::I32))
88        }
89    })?;
90    Ok(jit)
91}
92
93use std::sync::Arc;
94
95use std::sync::LazyLock;
// SAFETY: NOTE(review): these impls assert that `JITRunTime` may be moved to and
// shared between threads. Nothing visible in this file proves that; it depends on
// the fields of `JITRunTime`, which is defined in the `rt` module. The global `JIT`
// stores the runtime in a `static` behind `Arc<RwLock<_>>`, which needs both
// traits — confirm soundness against the type's definition.
unsafe impl Send for JITRunTime {}
unsafe impl Sync for JITRunTime {}
98
99//直接在这里增加一行 就可以导入一个模块
100static mut MODULES: &[(&str, &[(&str, &[Type], Type, *const u8)])] = &[("llm", &llm_module::LLM_NATIVE), ("root", &root_module::ROOT_NATIVE)];
101
102pub static JIT: LazyLock<Arc<RwLock<JITRunTime>>> = LazyLock::new(|| {
103    let jit = JITRunTime::new(|b| {
104        //这里注册所有的外部符号
105        for (name, _, _, fn_ptr) in STD {
106            b.symbol(name, fn_ptr);
107        }
108        for (name, _, _, fn_ptr) in ANY {
109            b.symbol(name, fn_ptr);
110        }
111        for (name, fns) in unsafe { MODULES.into_iter() } {
112            for (fn_name, _, _, fn_ptr) in *fns {
113                let full_name = format!("{}::{}", *name, *fn_name);
114                b.symbol(&full_name, *fn_ptr);
115            }
116        }
117    });
118    let mut jit = init_jit(jit).unwrap();
119    for (name, fns) in unsafe { MODULES.into_iter() } {
120        jit.compiler.symbols.add_module((*name).into());
121        for r in fns.into_iter() {
122            let full_name = format!("{}::{}", *name, r.0);
123            jit.add_native(&&full_name, r.0, r.1, r.2.clone()).unwrap();
124        }
125        jit.compiler.symbols.pop_module();
126    }
127    Arc::new(RwLock::new(jit))
128});
129
130pub fn import_code(name: &str, code: Vec<u8>) -> Result<()> {
131    JIT.write().unwrap().import_code(name, code)
132}
133
134pub fn import(name: &str, path: &str) -> Result<()> {
135    if root::contains(path) {
136        //优先从 root 文件系统里面 import
137        let code = root::get(path).unwrap();
138        if code.is_str() {
139            JIT.write().unwrap().import_code(name, code.as_str().as_bytes().to_vec())
140        } else {
141            JIT.write().unwrap().import_code(name, code.get_dynamic("code").ok_or(anyhow!("{:?} 没有 code 成员", code))?.as_str().as_bytes().to_vec())
142        }
143    } else {
144        JIT.write().unwrap().compiler.import_file(name, path)?;
145        Ok(())
146    }
147}
148
149pub fn infer(name: &str, arg_tys: &[Type]) -> Result<Type> {
150    JIT.write().unwrap().get_type(name, arg_tys)
151}
152
153pub fn get_fn(name: &str, arg_tys: &[Type]) -> Result<(*const u8, Type)> {
154    JIT.write().unwrap().get_fn_ptr(name, arg_tys)
155}
156
157pub fn load(code: Vec<u8>, arg_name: SmolStr) -> Result<(i64, Type)> {
158    JIT.write().unwrap().load(code, arg_name)
159}
160
161pub fn get_symbol(name: &str, params: Vec<Type>) -> Result<Type> {
162    Ok(Type::Symbol { id: JIT.read().unwrap().get_id(name)?, params })
163}
164
165pub fn disassemble(name: &str) -> Result<String> {
166    JIT.read().unwrap().compiler.symbols.disassemble(name)
167}
168
/// Returns the IR disassembly text for `name`.
///
/// Only compiled with the `ir-disassembly` feature. Takes the write lock on the
/// global [`JIT`] and forwards to `JITRunTime::disassemble_ir`.
///
/// # Panics
///
/// Panics if the [`JIT`] lock is poisoned.
#[cfg(feature = "ir-disassembly")]
pub fn disassemble_ir(name: &str) -> Result<String> {
    let mut runtime = JIT.write().unwrap();
    runtime.disassemble_ir(name)
}
173
#[cfg(test)]
mod tests {
    use super::{get_fn, import_code};
    use dynamic::{Dynamic, Type};

    /// `!=` between an `Any` argument and a string literal yields a `Bool` that
    /// is false for an empty string and true for a non-empty one.
    #[test]
    fn compares_any_with_string_literal_as_string() -> anyhow::Result<()> {
        const SCRIPT: &[u8] = br#"
            pub fn any_ne_empty(chat_path) {
                chat_path != ""
            }
            "#;
        import_code("vm_string_compare_any", SCRIPT.to_vec())?;

        let (raw, ret_ty) = get_fn("vm_string_compare_any::any_ne_empty", &[Type::Any])?;
        assert_eq!(ret_ty, Type::Bool);

        // SAFETY: NOTE(review): assumes the JIT compiles an `Any` parameter as a
        // `*const Dynamic` and a `Bool` result as a C `bool` — confirm against the backend.
        let any_ne_empty: extern "C" fn(*const Dynamic) -> bool = unsafe { std::mem::transmute(raw) };

        let empty = Dynamic::from("");
        let non_empty = Dynamic::from("chat");
        assert!(!any_ne_empty(&empty));
        assert!(any_ne_empty(&non_empty));
        Ok(())
    }

    /// An `i64` compared with, or added to, a string literal is handled as a string.
    #[test]
    fn compares_concrete_value_with_string_literal_as_string() -> anyhow::Result<()> {
        const SCRIPT: &[u8] = br#"
            pub fn int_eq_str(value: i64) {
                value == "42"
            }

            pub fn int_to_str(value: i64) {
                value + ""
            }
            "#;
        import_code("vm_string_compare_imm", SCRIPT.to_vec())?;

        let (eq_raw, eq_ret_ty) = get_fn("vm_string_compare_imm::int_eq_str", &[Type::I64])?;
        assert_eq!(eq_ret_ty, Type::Bool);

        // SAFETY: NOTE(review): assumes an `i64 -> Bool` function is emitted with
        // this C signature — confirm against the backend.
        let int_eq_str: extern "C" fn(i64) -> bool = unsafe { std::mem::transmute(eq_raw) };

        let (to_str_raw, to_str_ret_ty) = get_fn("vm_string_compare_imm::int_to_str", &[Type::I64])?;
        assert_eq!(to_str_ret_ty, Type::Any);
        // SAFETY: NOTE(review): assumes an `Any` result is returned as a
        // `*const Dynamic` — confirm against the backend.
        let int_to_str: extern "C" fn(i64) -> *const Dynamic = unsafe { std::mem::transmute(to_str_raw) };
        let text_ptr = int_to_str(42);
        // SAFETY: NOTE(review): assumes the returned pointer is non-null and still
        // valid at this point — confirm who owns the value.
        let text = unsafe { &*text_ptr };
        assert_eq!(text.as_str(), "42");

        assert!(int_eq_str(42));
        assert!(!int_eq_str(7));
        Ok(())
    }
}