1use crate::error_mapping;
10use crate::marshaling;
11use shape_abi_v1::{LanguageRuntimeLspConfig, PluginError};
12use std::collections::HashMap;
13use std::ffi::c_void;
14
/// A function that has been wrapped into Python source text and registered
/// with a `PythonRuntime`.
///
/// Values of this type are built by `PythonRuntime::compile` and kept in the
/// runtime's handle table until the handle is disposed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompiledFunction {
    /// Function name supplied by the caller at compile time.
    pub name: String,
    /// Generated Python source; it always defines a callable named
    /// `__shape_fn__` whose body is the caller's source text.
    pub python_source: String,
    /// Parameter names, in call order.
    pub param_names: Vec<String>,
    /// Line offset associated with the user-written body.
    /// `PythonRuntime::compile` currently always stores `0` here.
    // NOTE(review): presumably consumed by the error-mapping code to translate
    // Python line numbers back to the embedding source — confirm.
    pub shape_body_start_line: u32,
    /// Whether the body was compiled as an `async def` driven by `asyncio.run`.
    pub is_async: bool,
    /// Declared return type name, used when marshaling the result back.
    pub return_type: String,
}
31
32pub struct PythonRuntime {
34 functions: HashMap<usize, CompiledFunction>,
36 next_id: usize,
38}
39
40impl PythonRuntime {
41 pub fn new(_config_msgpack: &[u8]) -> Result<Self, String> {
47 #[cfg(feature = "pyo3")]
48 {
49 Self::activate_virtualenv();
53 }
54
55 Ok(PythonRuntime {
56 functions: HashMap::new(),
57 next_id: 1,
58 })
59 }
60
/// Best-effort activation of a project virtualenv inside the embedded
/// interpreter.
///
/// Candidate directories are tried in this order; the first one that exists
/// wins:
/// 1. `venvPath`/`venv` from `pyrightconfig.json` in the current directory
///    (a relative `venvPath` is resolved against the current directory),
/// 2. `.venv`, then `venv`, in the current directory,
/// 3. the directory named by the `VIRTUAL_ENV` environment variable.
///
/// If no candidate exists the function returns without touching the
/// interpreter. Failures are written to stderr and never propagated.
#[cfg(feature = "pyo3")]
fn activate_virtualenv() {
    use pyo3::prelude::*;

    /// Renders `s` as a double-quoted Python string literal, escaping the
    /// characters that would otherwise end or corrupt the literal.
    fn python_str_literal(s: &str) -> String {
        let mut out = String::with_capacity(s.len() + 2);
        out.push('"');
        for ch in s.chars() {
            match ch {
                '\\' => out.push_str("\\\\"),
                '"' => out.push_str("\\\""),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\0' => out.push_str("\\x00"),
                other => out.push(other),
            }
        }
        out.push('"');
        out
    }

    let cwd = std::env::current_dir().ok();

    let from_pyright_config = cwd.as_ref().and_then(|cwd| {
        let config_path = cwd.join("pyrightconfig.json");
        let contents = std::fs::read_to_string(&config_path).ok()?;
        let config: serde_json::Value = serde_json::from_str(&contents).ok()?;
        let venv_path = config.get("venvPath")?.as_str()?;
        let venv_name = config.get("venv")?.as_str()?;
        let base = if std::path::Path::new(venv_path).is_absolute() {
            std::path::PathBuf::from(venv_path)
        } else {
            cwd.join(venv_path)
        };
        let candidate = base.join(venv_name);
        candidate.is_dir().then_some(candidate)
    });

    let from_local_dir = || -> Option<std::path::PathBuf> {
        let cwd = cwd.as_ref()?;
        [".venv", "venv"]
            .iter()
            .map(|name| cwd.join(name))
            .find(|candidate| candidate.is_dir())
    };

    let from_env = || -> Option<std::path::PathBuf> {
        let path = std::path::PathBuf::from(std::env::var("VIRTUAL_ENV").ok()?);
        path.is_dir().then_some(path)
    };

    let Some(venv) = from_pyright_config
        .or_else(from_local_dir)
        .or_else(from_env)
    else {
        return;
    };
    let venv_str = venv.display().to_string();

    // The path is embedded as an escaped literal so that backslashes or quotes
    // in it cannot terminate the string or inject Python code.
    let code = format!(
        concat!(
            "import sys, site, os\n",
            "venv = {venv}\n",
            "sys.prefix = venv\n",
            "sys.exec_prefix = venv\n",
            "lib_dir = os.path.join(venv, \"lib\")\n",
            "if os.path.isdir(lib_dir):\n",
            "    for entry in os.listdir(lib_dir):\n",
            "        sp = os.path.join(lib_dir, entry, \"site-packages\")\n",
            "        if os.path.isdir(sp):\n",
            "            site.addsitedir(sp)\n",
            "            break\n",
        ),
        venv = python_str_literal(&venv_str),
    );

    // Activation is best effort: report and bail out rather than panic.
    let code = match std::ffi::CString::new(code) {
        Ok(code) => code,
        Err(e) => {
            eprintln!("shape-ext-python: failed to activate venv: {e}");
            return;
        }
    };

    Python::attach(|py| {
        if let Err(e) = py.run(&code, None, None) {
            eprintln!("shape-ext-python: failed to activate venv: {e}");
        }
    });
}
142
143 pub fn register_types(&mut self, _types_msgpack: &[u8]) -> Result<(), String> {
148 Ok(())
152 }
153
154 pub fn compile(
173 &mut self,
174 name: &str,
175 source: &str,
176 param_names: &[String],
177 param_types: &[String],
178 return_type: &str,
179 is_async: bool,
180 ) -> Result<*mut c_void, String> {
181 let params: Vec<String> = param_names
183 .iter()
184 .zip(param_types.iter())
185 .map(|(pname, ptype)| {
186 format!(
187 "{}: {}",
188 pname,
189 marshaling::shape_type_to_python_hint(ptype)
190 )
191 })
192 .collect();
193 let params_str = params.join(", ");
194 let return_hint = marshaling::shape_type_to_python_hint(return_type);
195
196 let indented_body: String = source
198 .lines()
199 .map(|line| format!(" {line}"))
200 .collect::<Vec<_>>()
201 .join("\n");
202
203 let python_source = if is_async {
204 let plain_params: Vec<&str> = param_names.iter().map(|s| s.as_str()).collect();
206 let call_args = plain_params.join(", ");
207 format!(
208 "import asyncio\n\
209 async def __shape_async__({params_str}) -> {return_hint}:\n\
210 {indented_body}\n\
211 def __shape_fn__({params_str}) -> {return_hint}:\n\
212 {sync_indent}return asyncio.run(__shape_async__({call_args}))\n",
213 sync_indent = " ",
214 )
215 } else {
216 format!("def __shape_fn__({params_str}) -> {return_hint}:\n{indented_body}")
217 };
218
219 let id = self.next_id;
220 self.next_id += 1;
221
222 let func = CompiledFunction {
223 name: name.to_string(),
224 python_source,
225 param_names: param_names.to_vec(),
226 shape_body_start_line: 0,
227 is_async,
228 return_type: return_type.to_string(),
229 };
230
231 self.functions.insert(id, func);
232
233 Ok(id as *mut c_void)
235 }
236
237 pub fn invoke(&self, handle: *mut c_void, args_msgpack: &[u8]) -> Result<Vec<u8>, String> {
241 let id = handle as usize;
242 let func = self
243 .functions
244 .get(&id)
245 .ok_or_else(|| format!("invalid function handle: {id}"))?;
246
247 #[cfg(feature = "pyo3")]
248 {
249 use pyo3::prelude::*;
250 use pyo3::types::PyModule;
251
252 Python::attach(|py| {
253 let source_cstring = std::ffi::CString::new(func.python_source.as_str())
255 .map_err(|e| format!("Invalid source (contains null byte): {}", e))?;
256 let code = PyModule::from_code(py, &source_cstring, c"<shape>", c"__shape__")
257 .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
258
259 let shape_fn = code
260 .getattr("__shape_fn__")
261 .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
262
263 let args_values: Vec<rmpv::Value> = if args_msgpack.is_empty() {
265 Vec::new()
266 } else {
267 rmp_serde::from_slice(args_msgpack)
268 .map_err(|e| format!("Failed to deserialize args: {}", e))?
269 };
270
271 let py_args: Vec<pyo3::Py<pyo3::PyAny>> = args_values
272 .iter()
273 .map(|v| marshaling::msgpack_to_pyobject(py, v))
274 .collect::<Result<_, _>>()?;
275
276 let py_tuple = pyo3::types::PyTuple::new(py, &py_args)
278 .map_err(|e| format!("Failed to create args tuple: {}", e))?;
279 let result = shape_fn
280 .call1(&py_tuple)
281 .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
282
283 let result_value =
285 marshaling::pyobject_to_typed_msgpack(py, &result, &func.return_type)?;
286 rmp_serde::to_vec(&result_value)
287 .map_err(|e| format!("Failed to serialize result: {}", e))
288 })
289 }
290
291 #[cfg(not(feature = "pyo3"))]
292 {
293 let _ = args_msgpack;
294 let _ = &func.python_source;
295 let _ = error_mapping::parse_traceback;
296 Err(format!(
297 "python runtime: pyo3 feature not enabled (function: {})",
298 func.name
299 ))
300 }
301 }
302
303 pub fn dispose_function(&mut self, handle: *mut c_void) {
305 let id = handle as usize;
306 self.functions.remove(&id);
307 }
308
/// Identifier of the language this runtime executes.
pub fn language_id() -> &'static str {
    const LANGUAGE_ID: &str = "python";
    LANGUAGE_ID
}
313
314 pub fn lsp_config() -> LanguageRuntimeLspConfig {
316 LanguageRuntimeLspConfig {
317 language_id: "python".into(),
318 server_command: vec!["pyright-langserver".into(), "--stdio".into()],
319 file_extension: ".py".into(),
320 extra_paths: Vec::new(),
321 }
322 }
323}
324
325pub unsafe extern "C" fn python_init(config: *const u8, config_len: usize) -> *mut c_void {
330 #[cfg(unix)]
337 promote_libpython_symbols();
338
339 let config_slice = if config.is_null() || config_len == 0 {
340 &[]
341 } else {
342 unsafe { std::slice::from_raw_parts(config, config_len) }
343 };
344
345 match PythonRuntime::new(config_slice) {
346 Ok(runtime) => Box::into_raw(Box::new(runtime)) as *mut c_void,
347 Err(_) => std::ptr::null_mut(),
348 }
349}
350
351pub unsafe extern "C" fn python_register_types(
352 instance: *mut c_void,
353 types_msgpack: *const u8,
354 types_len: usize,
355) -> i32 {
356 if instance.is_null() {
357 return PluginError::NotInitialized as i32;
358 }
359 let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
360 let types_slice = if types_msgpack.is_null() || types_len == 0 {
361 &[]
362 } else {
363 unsafe { std::slice::from_raw_parts(types_msgpack, types_len) }
364 };
365
366 match runtime.register_types(types_slice) {
367 Ok(()) => PluginError::Success as i32,
368 Err(_) => PluginError::InternalError as i32,
369 }
370}
371
372pub unsafe extern "C" fn python_compile(
373 instance: *mut c_void,
374 name: *const u8,
375 name_len: usize,
376 source: *const u8,
377 source_len: usize,
378 param_names_msgpack: *const u8,
379 param_names_len: usize,
380 param_types_msgpack: *const u8,
381 param_types_len: usize,
382 return_type: *const u8,
383 return_type_len: usize,
384 is_async: bool,
385 out_error: *mut *mut u8,
386 out_error_len: *mut usize,
387) -> *mut c_void {
388 if instance.is_null() {
389 return std::ptr::null_mut();
390 }
391 let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
392
393 let name_str = match str_from_raw(name, name_len) {
394 Some(s) => s,
395 None => {
396 write_error(out_error, out_error_len, "invalid function name");
397 return std::ptr::null_mut();
398 }
399 };
400 let source_str = match str_from_raw(source, source_len) {
401 Some(s) => s,
402 None => {
403 write_error(out_error, out_error_len, "invalid source text");
404 return std::ptr::null_mut();
405 }
406 };
407 let return_type_str = match str_from_raw(return_type, return_type_len) {
408 Some(s) => s,
409 None => "any", };
411
412 let param_names: Vec<String> = if param_names_msgpack.is_null() || param_names_len == 0 {
413 Vec::new()
414 } else {
415 let slice = unsafe { std::slice::from_raw_parts(param_names_msgpack, param_names_len) };
416 match rmp_serde::from_slice(slice) {
417 Ok(v) => v,
418 Err(_) => {
419 write_error(out_error, out_error_len, "invalid param names msgpack");
420 return std::ptr::null_mut();
421 }
422 }
423 };
424
425 let param_types: Vec<String> = if param_types_msgpack.is_null() || param_types_len == 0 {
426 Vec::new()
427 } else {
428 let slice = unsafe { std::slice::from_raw_parts(param_types_msgpack, param_types_len) };
429 match rmp_serde::from_slice(slice) {
430 Ok(v) => v,
431 Err(_) => {
432 write_error(out_error, out_error_len, "invalid param types msgpack");
433 return std::ptr::null_mut();
434 }
435 }
436 };
437
438 match runtime.compile(
439 name_str,
440 source_str,
441 ¶m_names,
442 ¶m_types,
443 return_type_str,
444 is_async,
445 ) {
446 Ok(handle) => handle,
447 Err(msg) => {
448 write_error(out_error, out_error_len, &msg);
449 std::ptr::null_mut()
450 }
451 }
452}
453
/// Copies `msg` into a fresh heap buffer and publishes it through the
/// out-parameters.
///
/// Does nothing when either out-pointer is null. The buffer is allocated as a
/// boxed slice, so its allocation size is exactly `*out_error_len` bytes;
/// this is what `python_free_buffer` assumes when it reclaims the memory.
/// Any value previously stored in `*out_error` is overwritten, not freed.
///
/// Callers must pass out-pointers that are either null or valid for writes.
fn write_error(out_error: *mut *mut u8, out_error_len: *mut usize, msg: &str) {
    if out_error.is_null() || out_error_len.is_null() {
        return;
    }
    let boxed: Box<[u8]> = Box::from(msg.as_bytes());
    let len = boxed.len();
    // Ownership moves to the caller, who returns it via `python_free_buffer`.
    let ptr = Box::into_raw(boxed) as *mut u8;
    // SAFETY: both pointers were checked for null; the caller must guarantee
    // they are valid for writes.
    unsafe {
        *out_error = ptr;
        *out_error_len = len;
    }
}
468
469pub unsafe extern "C" fn python_invoke(
470 instance: *mut c_void,
471 handle: *mut c_void,
472 args_msgpack: *const u8,
473 args_len: usize,
474 out_ptr: *mut *mut u8,
475 out_len: *mut usize,
476) -> i32 {
477 if instance.is_null() || out_ptr.is_null() || out_len.is_null() {
478 return PluginError::InvalidArgument as i32;
479 }
480 let runtime = unsafe { &*(instance as *const PythonRuntime) };
481 let args_slice = if args_msgpack.is_null() || args_len == 0 {
482 &[]
483 } else {
484 unsafe { std::slice::from_raw_parts(args_msgpack, args_len) }
485 };
486
487 match runtime.invoke(handle, args_slice) {
488 Ok(mut bytes) => {
489 let len = bytes.len();
490 let ptr = bytes.as_mut_ptr();
491 std::mem::forget(bytes);
492 unsafe {
493 *out_ptr = ptr;
494 *out_len = len;
495 }
496 PluginError::Success as i32
497 }
498 Err(msg) => {
499 let mut err_bytes = msg.into_bytes();
501 let len = err_bytes.len();
502 let ptr = err_bytes.as_mut_ptr();
503 std::mem::forget(err_bytes);
504 unsafe {
505 *out_ptr = ptr;
506 *out_len = len;
507 }
508 PluginError::NotImplemented as i32
509 }
510 }
511}
512
513pub unsafe extern "C" fn python_dispose_function(instance: *mut c_void, handle: *mut c_void) {
514 if instance.is_null() {
515 return;
516 }
517 let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
518 runtime.dispose_function(handle);
519}
520
/// C entry point: returns the NUL-terminated language identifier `"python"`.
///
/// The pointer refers to static data and must not be freed. `_instance` is
/// not used.
///
/// # Safety
/// No requirements are placed on `_instance`; it is never dereferenced.
pub unsafe extern "C" fn python_language_id(_instance: *mut c_void) -> *const std::ffi::c_char {
    const LANGUAGE_ID: &[u8] = b"python\0";
    LANGUAGE_ID.as_ptr() as *const std::ffi::c_char
}
525
526pub unsafe extern "C" fn python_get_lsp_config(
527 _instance: *mut c_void,
528 out_ptr: *mut *mut u8,
529 out_len: *mut usize,
530) -> i32 {
531 if out_ptr.is_null() || out_len.is_null() {
532 return PluginError::InvalidArgument as i32;
533 }
534 let config = PythonRuntime::lsp_config();
535 match rmp_serde::to_vec(&config) {
536 Ok(mut bytes) => {
537 let len = bytes.len();
538 let ptr = bytes.as_mut_ptr();
539 std::mem::forget(bytes);
540 unsafe {
541 *out_ptr = ptr;
542 *out_len = len;
543 }
544 PluginError::Success as i32
545 }
546 Err(_) => PluginError::InternalError as i32,
547 }
548}
549
/// C entry point: releases a buffer previously handed out by this plugin.
///
/// Null pointers and zero lengths are ignored.
///
/// # Safety
/// `ptr`/`len` must be exactly the pair the plugin returned, the allocation
/// must be `len` bytes, and the buffer must not be used or freed again
/// afterwards.
pub unsafe extern "C" fn python_free_buffer(ptr: *mut u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: the caller must guarantee the pair describes a live allocation
    // of exactly `len` bytes made by this plugin's allocator.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
555
556pub unsafe extern "C" fn python_drop(instance: *mut c_void) {
557 if !instance.is_null() {
558 let _ = unsafe { Box::from_raw(instance as *mut PythonRuntime) };
559 }
560}
561
562#[cfg(unix)]
572fn promote_libpython_symbols() {
573 const SONAMES: &[&[u8]] = &[
574 b"libpython3.13.so.1.0\0",
575 b"libpython3.13.so\0",
576 b"libpython3.12.so.1.0\0",
577 b"libpython3.12.so\0",
578 b"libpython3.11.so.1.0\0",
579 b"libpython3.11.so\0",
580 b"libpython3.so\0",
581 ];
582 for soname in SONAMES {
583 let handle = unsafe {
584 libc::dlopen(
585 soname.as_ptr() as *const std::ffi::c_char,
586 libc::RTLD_NOLOAD | libc::RTLD_NOW | libc::RTLD_GLOBAL,
587 )
588 };
589 if !handle.is_null() {
590 unsafe { libc::dlclose(handle) };
591 return;
592 }
593 }
594 }
597
/// Views `len` bytes at `ptr` as UTF-8 text.
///
/// Returns `None` when `ptr` is null, `len` is zero, or the bytes are not
/// valid UTF-8. The returned lifetime is chosen by the caller: the memory
/// must stay valid and unmodified for as long as the string is used.
fn str_from_raw<'a>(ptr: *const u8, len: usize) -> Option<&'a str> {
    let usable = !ptr.is_null() && len != 0;
    if !usable {
        return None;
    }
    // SAFETY: non-null and non-empty were checked; the caller must guarantee
    // `ptr` is valid for reads of `len` bytes.
    let bytes: &'a [u8] = unsafe { std::slice::from_raw_parts(ptr, len) };
    match std::str::from_utf8(bytes) {
        Ok(text) => Some(text),
        Err(_) => None,
    }
}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// The static LSP configuration points at pyright over stdio.
    #[test]
    fn lsp_config_exposes_pyright_defaults() {
        let config = PythonRuntime::lsp_config();

        assert_eq!(config.language_id, "python");
        let expected_command = vec![
            String::from("pyright-langserver"),
            String::from("--stdio"),
        ];
        assert_eq!(config.server_command, expected_command);
        assert_eq!(config.file_extension, ".py");
        assert!(config.extra_paths.is_empty());
    }

    /// The C entry point yields a buffer that decodes back into the config
    /// and can be released through `python_free_buffer`.
    #[test]
    fn python_get_lsp_config_returns_valid_msgpack_payload() {
        let mut out_ptr: *mut u8 = std::ptr::null_mut();
        let mut out_len: usize = 0;

        // SAFETY: both out-parameters point at live locals.
        let code =
            unsafe { python_get_lsp_config(std::ptr::null_mut(), &mut out_ptr, &mut out_len) };
        assert_eq!(code, PluginError::Success as i32);
        assert!(!out_ptr.is_null());
        assert!(out_len > 0);

        {
            // SAFETY: the call above reported success, so the pair describes
            // a live buffer owned by this test.
            let payload = unsafe { std::slice::from_raw_parts(out_ptr, out_len) };
            let decoded: LanguageRuntimeLspConfig =
                rmp_serde::from_slice(payload).expect("payload should decode");
            assert_eq!(decoded.language_id, "python");
            assert_eq!(decoded.file_extension, ".py");
        }

        // SAFETY: the buffer came from the plugin and is released exactly once.
        unsafe { python_free_buffer(out_ptr, out_len) };
    }
}