1use shape_abi_v1::{ErrorModel, LanguageRuntimeLspConfig, LanguageRuntimeVTable};
7use shape_ast::error::{Result, ShapeError};
8use std::ffi::{CStr, c_void};
9use std::sync::Arc;
10
/// A function compiled by a plugin language runtime, identified by an opaque plugin handle.
///
/// Each value holds a strong reference to the runtime state that produced it, so the plugin
/// instance is not dropped while any compiled function still refers to it.
///
/// Cloning copies the raw `handle`. All clones therefore refer to the same plugin-side
/// function, and disposing it through any one clone affects all of them.
#[derive(Clone)]
pub struct CompiledForeignFunction {
    /// Opaque handle returned by the plugin's `compile` hook (non-null when constructed in
    /// `PluginLanguageRuntime::compile`).
    handle: *mut c_void,
    /// Keeps the owning plugin instance alive for as long as this handle exists.
    _runtime: Arc<LanguageRuntimeState>,
}

// SAFETY: the raw `handle` is only ever passed back into the plugin together with the
// instance kept alive by `_runtime`.
// NOTE(review): soundness relies on the plugin ABI allowing its handles to be used from any
// thread — confirm this is part of the plugin contract.
unsafe impl Send for CompiledForeignFunction {}
// SAFETY: shared references only expose the handle for read-only forwarding to the plugin.
// NOTE(review): this assumes concurrent plugin calls with the same handle are permitted —
// confirm against the plugin ABI.
unsafe impl Sync for CompiledForeignFunction {}
23
/// Shared ownership of a live plugin instance together with the vtable used to drive it.
///
/// It is wrapped in an `Arc` and shared between the runtime and every compiled function. The
/// plugin's `drop` hook runs when the last reference goes away.
struct LanguageRuntimeState {
    /// Function table exported by the plugin.
    vtable: &'static LanguageRuntimeVTable,
    /// Opaque instance pointer returned by the vtable's `init` hook.
    instance: *mut c_void,
}

// SAFETY: the instance pointer is never dereferenced on the host side; it is only handed back
// to the plugin's own functions.
// NOTE(review): this requires the plugin instance to be movable across threads — confirm
// against the plugin ABI.
unsafe impl Send for LanguageRuntimeState {}
// SAFETY: host code only reads the fields through `&self`.
// NOTE(review): because `PluginLanguageRuntime` methods take `&self`, this implies the
// plugin's functions may be called concurrently on one instance — confirm that plugins are
// thread-safe.
unsafe impl Sync for LanguageRuntimeState {}
32
33impl Drop for LanguageRuntimeState {
34 fn drop(&mut self) {
35 if let Some(drop_fn) = self.vtable.drop {
36 unsafe { drop_fn(self.instance) };
37 }
38 }
39}
40
/// Host-side handle to a language runtime loaded from a plugin vtable.
///
/// It compiles and invokes foreign functions and exposes optional plugin metadata: the LSP
/// configuration, bundled Shape source, and error model.
pub struct PluginLanguageRuntime {
    /// Language identifier reported by the plugin, cached once at construction.
    language_id: String,
    /// Plugin instance and vtable, shared with every `CompiledForeignFunction`.
    state: Arc<LanguageRuntimeState>,
    /// Error model copied from the vtable at construction.
    error_model: ErrorModel,
}
50
/// Language-server configuration reported by a plugin.
///
/// This is the host-side copy of the ABI type `LanguageRuntimeLspConfig`, decoded in
/// `PluginLanguageRuntime::lsp_config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeLspConfig {
    /// Language identifier as reported in the LSP payload.
    pub language_id: String,
    /// Command used to start the language server
    /// (presumably program followed by its arguments — confirm with plugin authors).
    pub server_command: Vec<String>,
    /// File extension associated with this language (exact format, e.g. with or without a
    /// leading dot, is not established here — TODO confirm).
    pub file_extension: String,
    /// Additional paths supplied by the plugin (intended use not visible here — TODO confirm).
    pub extra_paths: Vec<String>,
}
59
60impl PluginLanguageRuntime {
61 pub fn new(vtable: &'static LanguageRuntimeVTable, config: &serde_json::Value) -> Result<Self> {
63 let config_bytes = rmp_serde::to_vec(config).map_err(|e| ShapeError::RuntimeError {
64 message: format!("Failed to serialize language runtime config: {}", e),
65 location: None,
66 })?;
67
68 let init_fn = vtable.init.ok_or_else(|| ShapeError::RuntimeError {
69 message: "Language runtime vtable has no init function".to_string(),
70 location: None,
71 })?;
72
73 let instance = unsafe { init_fn(config_bytes.as_ptr(), config_bytes.len()) };
74 if instance.is_null() {
75 return Err(ShapeError::RuntimeError {
76 message: "Language runtime init returned null".to_string(),
77 location: None,
78 });
79 }
80
81 let lang_id_fn = vtable.language_id.ok_or_else(|| ShapeError::RuntimeError {
83 message: "Language runtime vtable has no language_id function".to_string(),
84 location: None,
85 })?;
86 let lang_ptr = unsafe { lang_id_fn(instance) };
87 let language_id = if lang_ptr.is_null() {
88 return Err(ShapeError::RuntimeError {
89 message: "Language runtime returned null language_id".to_string(),
90 location: None,
91 });
92 } else {
93 unsafe { CStr::from_ptr(lang_ptr) }
94 .to_string_lossy()
95 .to_string()
96 };
97
98 let error_model = vtable.error_model;
99 let state = Arc::new(LanguageRuntimeState { vtable, instance });
100
101 Ok(Self {
102 language_id,
103 state,
104 error_model,
105 })
106 }
107
108 pub fn language_id(&self) -> &str {
110 &self.language_id
111 }
112
113 pub fn has_dynamic_errors(&self) -> bool {
118 self.error_model == ErrorModel::Dynamic
119 }
120
121 pub fn lsp_config(&self) -> Result<Option<RuntimeLspConfig>> {
123 let get_lsp_config = match self.state.vtable.get_lsp_config {
124 Some(f) => f,
125 None => return Ok(None),
126 };
127
128 let mut out_ptr: *mut u8 = std::ptr::null_mut();
129 let mut out_len: usize = 0;
130 let rc = unsafe { get_lsp_config(self.state.instance, &mut out_ptr, &mut out_len) };
131 if rc != 0 {
132 return Err(ShapeError::RuntimeError {
133 message: format!(
134 "Language runtime '{}' get_lsp_config failed (error code {})",
135 self.language_id, rc
136 ),
137 location: None,
138 });
139 }
140
141 if out_ptr.is_null() || out_len == 0 {
142 return Ok(None);
143 }
144
145 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len) }.to_vec();
146 if let Some(free_fn) = self.state.vtable.free_buffer {
147 unsafe { free_fn(out_ptr, out_len) };
148 }
149
150 let decoded: LanguageRuntimeLspConfig =
151 rmp_serde::from_slice(&bytes).map_err(|e| ShapeError::RuntimeError {
152 message: format!(
153 "Language runtime '{}' returned invalid get_lsp_config payload: {}",
154 self.language_id, e
155 ),
156 location: None,
157 })?;
158
159 Ok(Some(RuntimeLspConfig {
160 language_id: decoded.language_id,
161 server_command: decoded.server_command,
162 file_extension: decoded.file_extension,
163 extra_paths: decoded.extra_paths,
164 }))
165 }
166
167 pub fn register_types(&self, types_msgpack: &[u8]) -> Result<()> {
169 let register_fn = match self.state.vtable.register_types {
170 Some(f) => f,
171 None => return Ok(()), };
173
174 let rc = unsafe {
175 register_fn(
176 self.state.instance,
177 types_msgpack.as_ptr(),
178 types_msgpack.len(),
179 )
180 };
181 if rc != 0 {
182 return Err(ShapeError::RuntimeError {
183 message: format!(
184 "Language runtime '{}' register_types failed (error code {})",
185 self.language_id, rc
186 ),
187 location: None,
188 });
189 }
190 Ok(())
191 }
192
193 pub fn compile(
195 &self,
196 name: &str,
197 source: &str,
198 param_names: &[String],
199 param_types: &[String],
200 return_type: Option<&str>,
201 is_async: bool,
202 ) -> Result<CompiledForeignFunction> {
203 let compile_fn = self
204 .state
205 .vtable
206 .compile
207 .ok_or_else(|| ShapeError::RuntimeError {
208 message: format!(
209 "Language runtime '{}' has no compile function",
210 self.language_id
211 ),
212 location: None,
213 })?;
214
215 let names_bytes = rmp_serde::to_vec(param_names).map_err(|e| ShapeError::RuntimeError {
216 message: format!("Failed to serialize param names: {}", e),
217 location: None,
218 })?;
219 let types_bytes = rmp_serde::to_vec(param_types).map_err(|e| ShapeError::RuntimeError {
220 message: format!("Failed to serialize param types: {}", e),
221 location: None,
222 })?;
223 let return_type_str = return_type.unwrap_or("");
224
225 let mut out_error: *mut u8 = std::ptr::null_mut();
226 let mut out_error_len: usize = 0;
227
228 let handle = unsafe {
229 compile_fn(
230 self.state.instance,
231 name.as_ptr(),
232 name.len(),
233 source.as_ptr(),
234 source.len(),
235 names_bytes.as_ptr(),
236 names_bytes.len(),
237 types_bytes.as_ptr(),
238 types_bytes.len(),
239 return_type_str.as_ptr(),
240 return_type_str.len(),
241 is_async,
242 &mut out_error,
243 &mut out_error_len,
244 )
245 };
246
247 if handle.is_null() {
248 let msg = if !out_error.is_null() && out_error_len > 0 {
249 let error_bytes =
250 unsafe { std::slice::from_raw_parts(out_error, out_error_len) }.to_vec();
251 if let Some(free_fn) = self.state.vtable.free_buffer {
252 unsafe { free_fn(out_error, out_error_len) };
253 }
254 String::from_utf8_lossy(&error_bytes).to_string()
255 } else {
256 "unknown compilation error".to_string()
257 };
258
259 return Err(ShapeError::RuntimeError {
260 message: format!(
261 "Language runtime '{}' failed to compile foreign function '{}': {}",
262 self.language_id, name, msg
263 ),
264 location: None,
265 });
266 }
267
268 Ok(CompiledForeignFunction {
269 handle,
270 _runtime: Arc::clone(&self.state),
271 })
272 }
273
274 pub fn invoke(
276 &self,
277 compiled: &CompiledForeignFunction,
278 args_msgpack: &[u8],
279 ) -> Result<Vec<u8>> {
280 let invoke_fn = self
281 .state
282 .vtable
283 .invoke
284 .ok_or_else(|| ShapeError::RuntimeError {
285 message: format!(
286 "Language runtime '{}' has no invoke function",
287 self.language_id
288 ),
289 location: None,
290 })?;
291
292 let mut out_ptr: *mut u8 = std::ptr::null_mut();
293 let mut out_len: usize = 0;
294
295 let rc = unsafe {
296 invoke_fn(
297 self.state.instance,
298 compiled.handle,
299 args_msgpack.as_ptr(),
300 args_msgpack.len(),
301 &mut out_ptr,
302 &mut out_len,
303 )
304 };
305
306 if rc != 0 {
307 let msg = if !out_ptr.is_null() && out_len > 0 {
309 let error_bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len) }.to_vec();
310 if let Some(free_fn) = self.state.vtable.free_buffer {
311 unsafe { free_fn(out_ptr, out_len) };
312 }
313 String::from_utf8_lossy(&error_bytes).to_string()
314 } else {
315 format!("error code {}", rc)
316 };
317 return Err(ShapeError::RuntimeError {
318 message: format!(
319 "Language runtime '{}' invoke failed: {}",
320 self.language_id, msg
321 ),
322 location: None,
323 });
324 }
325
326 if out_ptr.is_null() || out_len == 0 {
327 return Ok(vec![]);
328 }
329
330 let result = unsafe { std::slice::from_raw_parts(out_ptr, out_len) }.to_vec();
331
332 if let Some(free_fn) = self.state.vtable.free_buffer {
334 unsafe { free_fn(out_ptr, out_len) };
335 }
336
337 Ok(result)
338 }
339
340 pub fn dispose_function(&self, compiled: &CompiledForeignFunction) {
342 if let Some(dispose_fn) = self.state.vtable.dispose_function {
343 unsafe {
344 dispose_fn(self.state.instance, compiled.handle);
345 }
346 }
347 }
348
349 pub fn shape_source(&self) -> Result<Option<(String, String)>> {
357 let get_source_fn = match self.state.vtable.get_shape_source {
358 Some(f) => f,
359 None => return Ok(None),
360 };
361
362 let mut out_ptr: *mut u8 = std::ptr::null_mut();
363 let mut out_len: usize = 0;
364 let rc = unsafe { get_source_fn(self.state.instance, &mut out_ptr, &mut out_len) };
365 if rc != 0 {
366 return Err(ShapeError::RuntimeError {
367 message: format!(
368 "Language runtime '{}' get_shape_source failed (error code {})",
369 self.language_id, rc
370 ),
371 location: None,
372 });
373 }
374
375 if out_ptr.is_null() || out_len == 0 {
376 return Ok(None);
377 }
378
379 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len) }.to_vec();
380 if let Some(free_fn) = self.state.vtable.free_buffer {
381 unsafe { free_fn(out_ptr, out_len) };
382 }
383
384 let source = String::from_utf8(bytes).map_err(|e| ShapeError::RuntimeError {
385 message: format!(
386 "Language runtime '{}' returned invalid UTF-8 shape source: {}",
387 self.language_id, e
388 ),
389 location: None,
390 })?;
391
392 Ok(Some((self.language_id.clone(), source)))
395 }
396}