Skip to main content

shape_ext_python/
runtime.rs

1//! CPython interpreter management and foreign function execution.
2//!
3//! This module owns the Python interpreter lifecycle and implements the
4//! core LanguageRuntime operations: init, compile, invoke, dispose.
5//!
6//! When the `pyo3` feature is enabled, this uses pyo3 to embed CPython.
7//! Without it, all operations return stub errors.
8
9use crate::error_mapping;
10use crate::marshaling;
11use shape_abi_v1::{LanguageRuntimeLspConfig, PluginError};
12use std::collections::HashMap;
13use std::ffi::c_void;
14
/// Metadata and generated source for one compiled foreign (Python) function.
///
/// Instances are created by `PythonRuntime::compile` and stored inside the
/// runtime; the host only ever sees an opaque handle that maps back to one
/// of these records.
#[derive(Debug, Clone)]
pub struct CompiledFunction {
    /// The function name in Shape.
    pub name: String,
    /// Generated Python source that defines the `__shape_fn__` wrapper.
    pub python_source: String,
    /// Parameter names in call order.
    pub param_names: Vec<String>,
    /// Shape source line where the foreign block body starts (for error mapping).
    pub shape_body_start_line: u32,
    /// Whether the function was declared `async` in Shape.
    pub is_async: bool,
    /// Declared return type string from Shape (e.g. "Result<int>",
    /// "Result<{id: int, name: string}>"). Used by the typed marshalling path
    /// to validate/coerce Python return values.
    pub return_type: String,
}
31
32/// The Python runtime instance. One per `init()` call.
33pub struct PythonRuntime {
34    /// Compiled function handles, keyed by an incrementing ID.
35    functions: HashMap<usize, CompiledFunction>,
36    /// Next handle ID.
37    next_id: usize,
38}
39
40impl PythonRuntime {
41    /// Initialize a new Python runtime.
42    ///
43    /// `_config_msgpack` is the MessagePack-encoded configuration from the
44    /// host. Currently unused -- reserved for future settings like
45    /// virtualenv path, Python version constraints, etc.
46    pub fn new(_config_msgpack: &[u8]) -> Result<Self, String> {
47        #[cfg(feature = "pyo3")]
48        {
49            // Activate virtualenv if one is detected. This mirrors what
50            // `source .venv/bin/activate` does: update sys.prefix and add
51            // site-packages so that `import <pkg>` works for venv packages.
52            Self::activate_virtualenv();
53        }
54
55        Ok(PythonRuntime {
56            functions: HashMap::new(),
57            next_id: 1,
58        })
59    }
60
61    /// Detect and activate a Python virtualenv.
62    ///
63    /// Mirrors Pyright's discovery order so the runtime resolves the same
64    /// environment as the language server:
65    /// 1. `pyrightconfig.json` `venvPath` + `venv` in the working directory
66    /// 2. `.venv/` in the working directory
67    /// 3. `venv/` in the working directory
68    /// 4. `VIRTUAL_ENV` environment variable
69    ///
70    /// When found, adds the venv's site-packages to `sys.path` and updates
71    /// `sys.prefix` so that `import <pkg>` works for venv-installed packages.
72    #[cfg(feature = "pyo3")]
73    fn activate_virtualenv() {
74        use pyo3::prelude::*;
75
76        let cwd = std::env::current_dir().ok();
77
78        // 1. Check pyrightconfig.json for venvPath + venv
79        let from_pyright_config = cwd.as_ref().and_then(|cwd| {
80            let config_path = cwd.join("pyrightconfig.json");
81            let contents = std::fs::read_to_string(&config_path).ok()?;
82            let config: serde_json::Value = serde_json::from_str(&contents).ok()?;
83            let venv_path = config.get("venvPath")?.as_str()?;
84            let venv_name = config.get("venv")?.as_str()?;
85            let base = if std::path::Path::new(venv_path).is_absolute() {
86                std::path::PathBuf::from(venv_path)
87            } else {
88                cwd.join(venv_path)
89            };
90            let candidate = base.join(venv_name);
91            candidate.is_dir().then_some(candidate)
92        });
93
94        // 2-3. Check .venv/ and venv/ in working directory
95        let from_local_dir = || -> Option<std::path::PathBuf> {
96            let cwd = cwd.as_ref()?;
97            for name in &[".venv", "venv"] {
98                let candidate = cwd.join(name);
99                if candidate.is_dir() {
100                    return Some(candidate);
101                }
102            }
103            None
104        };
105
106        // 4. VIRTUAL_ENV environment variable
107        let from_env = || -> Option<std::path::PathBuf> {
108            let path = std::path::PathBuf::from(std::env::var("VIRTUAL_ENV").ok()?);
109            path.is_dir().then_some(path)
110        };
111
112        let venv = from_pyright_config
113            .or_else(from_local_dir)
114            .or_else(from_env);
115
116        let Some(venv) = venv else { return };
117        let venv_str = venv.display().to_string();
118
119        Python::attach(|py| {
120            let code = format!(
121                concat!(
122                    "import sys, site, os\n",
123                    "venv = \"{venv}\"\n",
124                    "sys.prefix = venv\n",
125                    "sys.exec_prefix = venv\n",
126                    "lib_dir = os.path.join(venv, \"lib\")\n",
127                    "if os.path.isdir(lib_dir):\n",
128                    "    for entry in os.listdir(lib_dir):\n",
129                    "        sp = os.path.join(lib_dir, entry, \"site-packages\")\n",
130                    "        if os.path.isdir(sp):\n",
131                    "            site.addsitedir(sp)\n",
132                    "            break\n",
133                ),
134                venv = venv_str,
135            );
136
137            if let Err(e) = py.run(&std::ffi::CString::new(code).unwrap(), None, None) {
138                eprintln!("shape-ext-python: failed to activate venv: {e}");
139            }
140        });
141    }
142
143    /// Register Shape type schemas for Python stub generation.
144    ///
145    /// The runtime receives the full set of Shape types so it can generate
146    /// Python dataclass stubs that the user's code can reference.
147    pub fn register_types(&mut self, _types_msgpack: &[u8]) -> Result<(), String> {
148        // Stub: the real implementation will deserialize TypeSchemaExport[]
149        // and generate Python dataclass definitions injected into the
150        // interpreter's namespace.
151        Ok(())
152    }
153
154    /// Compile a foreign function body into a callable Python function.
155    ///
156    /// When `is_async` is false, wraps the user's body in:
157    /// ```python
158    /// def __shape_fn__(param1, param2) -> return_type:
159    ///     <body>
160    /// ```
161    ///
162    /// When `is_async` is true, wraps it in an async def with an asyncio runner:
163    /// ```python
164    /// import asyncio
165    /// async def __shape_async__(param1, param2) -> return_type:
166    ///     <body>
167    /// def __shape_fn__(param1, param2) -> return_type:
168    ///     return asyncio.run(__shape_async__(param1, param2))
169    /// ```
170    ///
171    /// Returns a handle that can be passed to `invoke()`.
172    pub fn compile(
173        &mut self,
174        name: &str,
175        source: &str,
176        param_names: &[String],
177        param_types: &[String],
178        return_type: &str,
179        is_async: bool,
180    ) -> Result<*mut c_void, String> {
181        // Build type-hinted parameter list.
182        let params: Vec<String> = param_names
183            .iter()
184            .zip(param_types.iter())
185            .map(|(pname, ptype)| {
186                format!(
187                    "{}: {}",
188                    pname,
189                    marshaling::shape_type_to_python_hint(ptype)
190                )
191            })
192            .collect();
193        let params_str = params.join(", ");
194        let return_hint = marshaling::shape_type_to_python_hint(return_type);
195
196        // Indent the user body by 4 spaces.
197        let indented_body: String = source
198            .lines()
199            .map(|line| format!("    {line}"))
200            .collect::<Vec<_>>()
201            .join("\n");
202
203        let python_source = if is_async {
204            // Wrap in async def + synchronous asyncio.run() entry point.
205            let plain_params: Vec<&str> = param_names.iter().map(|s| s.as_str()).collect();
206            let call_args = plain_params.join(", ");
207            format!(
208                "import asyncio\n\
209                 async def __shape_async__({params_str}) -> {return_hint}:\n\
210                 {indented_body}\n\
211                 def __shape_fn__({params_str}) -> {return_hint}:\n\
212                 {sync_indent}return asyncio.run(__shape_async__({call_args}))\n",
213                sync_indent = "    ",
214            )
215        } else {
216            format!("def __shape_fn__({params_str}) -> {return_hint}:\n{indented_body}")
217        };
218
219        let id = self.next_id;
220        self.next_id += 1;
221
222        let func = CompiledFunction {
223            name: name.to_string(),
224            python_source,
225            param_names: param_names.to_vec(),
226            shape_body_start_line: 0,
227            is_async,
228            return_type: return_type.to_string(),
229        };
230
231        self.functions.insert(id, func);
232
233        // The handle is the function ID cast to a pointer.
234        Ok(id as *mut c_void)
235    }
236
237    /// Invoke a previously compiled function with msgpack-encoded arguments.
238    ///
239    /// Returns msgpack-encoded result on success.
240    pub fn invoke(&self, handle: *mut c_void, args_msgpack: &[u8]) -> Result<Vec<u8>, String> {
241        let id = handle as usize;
242        let func = self
243            .functions
244            .get(&id)
245            .ok_or_else(|| format!("invalid function handle: {id}"))?;
246
247        #[cfg(feature = "pyo3")]
248        {
249            use pyo3::prelude::*;
250            use pyo3::types::PyModule;
251
252            Python::attach(|py| {
253                // 1. Execute the compiled source to define __shape_fn__
254                let source_cstring = std::ffi::CString::new(func.python_source.as_str())
255                    .map_err(|e| format!("Invalid source (contains null byte): {}", e))?;
256                let code = PyModule::from_code(py, &source_cstring, c"<shape>", c"__shape__")
257                    .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
258
259                let shape_fn = code
260                    .getattr("__shape_fn__")
261                    .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
262
263                // 2. Deserialize msgpack args -> Vec<rmpv::Value> -> Vec<Py<PyAny>>
264                let args_values: Vec<rmpv::Value> = if args_msgpack.is_empty() {
265                    Vec::new()
266                } else {
267                    rmp_serde::from_slice(args_msgpack)
268                        .map_err(|e| format!("Failed to deserialize args: {}", e))?
269                };
270
271                let py_args: Vec<pyo3::Py<pyo3::PyAny>> = args_values
272                    .iter()
273                    .map(|v| marshaling::msgpack_to_pyobject(py, v))
274                    .collect::<Result<_, _>>()?;
275
276                // 3. Call the function
277                let py_tuple = pyo3::types::PyTuple::new(py, &py_args)
278                    .map_err(|e| format!("Failed to create args tuple: {}", e))?;
279                let result = shape_fn
280                    .call1(&py_tuple)
281                    .map_err(|e| error_mapping::format_python_error(py, &e, func))?;
282
283                // 4. Convert result -> msgpack (type-aware path)
284                let result_value =
285                    marshaling::pyobject_to_typed_msgpack(py, &result, &func.return_type)?;
286                rmp_serde::to_vec(&result_value)
287                    .map_err(|e| format!("Failed to serialize result: {}", e))
288            })
289        }
290
291        #[cfg(not(feature = "pyo3"))]
292        {
293            let _ = args_msgpack;
294            let _ = &func.python_source;
295            let _ = error_mapping::parse_traceback;
296            Err(format!(
297                "python runtime: pyo3 feature not enabled (function: {})",
298                func.name
299            ))
300        }
301    }
302
303    /// Dispose a compiled function handle, freeing associated resources.
304    pub fn dispose_function(&mut self, handle: *mut c_void) {
305        let id = handle as usize;
306        self.functions.remove(&id);
307    }
308
309    /// Return the language identifier.
310    pub fn language_id() -> &'static str {
311        "python"
312    }
313
314    /// Return LSP configuration for Python (pyright).
315    pub fn lsp_config() -> LanguageRuntimeLspConfig {
316        LanguageRuntimeLspConfig {
317            language_id: "python".into(),
318            server_command: vec!["pyright-langserver".into(), "--stdio".into()],
319            file_extension: ".py".into(),
320            extra_paths: Vec::new(),
321        }
322    }
323}
324
325// ============================================================================
326// C ABI callback functions (wired from lib.rs vtable)
327// ============================================================================
328
329pub unsafe extern "C" fn python_init(config: *const u8, config_len: usize) -> *mut c_void {
330    // Promote libpython symbols to global visibility before any Python code
331    // runs. Python C extensions (numpy, pandas, etc.) loaded via `import`
332    // expect CPython API symbols (PyExc_ValueError, etc.) to be globally
333    // visible. Since the host loads this .so with RTLD_LOCAL, libpython's
334    // symbols are hidden. Re-opening with RTLD_NOLOAD | RTLD_GLOBAL
335    // promotes them without loading a second copy.
336    #[cfg(unix)]
337    promote_libpython_symbols();
338
339    let config_slice = if config.is_null() || config_len == 0 {
340        &[]
341    } else {
342        unsafe { std::slice::from_raw_parts(config, config_len) }
343    };
344
345    match PythonRuntime::new(config_slice) {
346        Ok(runtime) => Box::into_raw(Box::new(runtime)) as *mut c_void,
347        Err(_) => std::ptr::null_mut(),
348    }
349}
350
351pub unsafe extern "C" fn python_register_types(
352    instance: *mut c_void,
353    types_msgpack: *const u8,
354    types_len: usize,
355) -> i32 {
356    if instance.is_null() {
357        return PluginError::NotInitialized as i32;
358    }
359    let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
360    let types_slice = if types_msgpack.is_null() || types_len == 0 {
361        &[]
362    } else {
363        unsafe { std::slice::from_raw_parts(types_msgpack, types_len) }
364    };
365
366    match runtime.register_types(types_slice) {
367        Ok(()) => PluginError::Success as i32,
368        Err(_) => PluginError::InternalError as i32,
369    }
370}
371
372pub unsafe extern "C" fn python_compile(
373    instance: *mut c_void,
374    name: *const u8,
375    name_len: usize,
376    source: *const u8,
377    source_len: usize,
378    param_names_msgpack: *const u8,
379    param_names_len: usize,
380    param_types_msgpack: *const u8,
381    param_types_len: usize,
382    return_type: *const u8,
383    return_type_len: usize,
384    is_async: bool,
385    out_error: *mut *mut u8,
386    out_error_len: *mut usize,
387) -> *mut c_void {
388    if instance.is_null() {
389        return std::ptr::null_mut();
390    }
391    let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
392
393    let name_str = match str_from_raw(name, name_len) {
394        Some(s) => s,
395        None => {
396            write_error(out_error, out_error_len, "invalid function name");
397            return std::ptr::null_mut();
398        }
399    };
400    let source_str = match str_from_raw(source, source_len) {
401        Some(s) => s,
402        None => {
403            write_error(out_error, out_error_len, "invalid source text");
404            return std::ptr::null_mut();
405        }
406    };
407    let return_type_str = match str_from_raw(return_type, return_type_len) {
408        Some(s) => s,
409        None => "any", // Default to "any" for generic/complex return types
410    };
411
412    let param_names: Vec<String> = if param_names_msgpack.is_null() || param_names_len == 0 {
413        Vec::new()
414    } else {
415        let slice = unsafe { std::slice::from_raw_parts(param_names_msgpack, param_names_len) };
416        match rmp_serde::from_slice(slice) {
417            Ok(v) => v,
418            Err(_) => {
419                write_error(out_error, out_error_len, "invalid param names msgpack");
420                return std::ptr::null_mut();
421            }
422        }
423    };
424
425    let param_types: Vec<String> = if param_types_msgpack.is_null() || param_types_len == 0 {
426        Vec::new()
427    } else {
428        let slice = unsafe { std::slice::from_raw_parts(param_types_msgpack, param_types_len) };
429        match rmp_serde::from_slice(slice) {
430            Ok(v) => v,
431            Err(_) => {
432                write_error(out_error, out_error_len, "invalid param types msgpack");
433                return std::ptr::null_mut();
434            }
435        }
436    };
437
438    match runtime.compile(
439        name_str,
440        source_str,
441        &param_names,
442        &param_types,
443        return_type_str,
444        is_async,
445    ) {
446        Ok(handle) => handle,
447        Err(msg) => {
448            write_error(out_error, out_error_len, &msg);
449            std::ptr::null_mut()
450        }
451    }
452}
453
/// Copy `msg` into a freshly allocated UTF-8 buffer and publish it through
/// `out_error` / `out_error_len`; ownership of the buffer passes to the caller.
///
/// Does nothing when either output pointer is null. The message is not
/// NUL-terminated; `*out_error_len` is its length in bytes.
fn write_error(out_error: *mut *mut u8, out_error_len: *mut usize, msg: &str) {
    if out_error.is_null() || out_error_len.is_null() {
        return;
    }
    // A boxed slice owns exactly `len` bytes, so the buffer can later be
    // rebuilt from just (ptr, len).
    let buffer: Box<[u8]> = Box::from(msg.as_bytes());
    let len = buffer.len();
    let ptr = Box::into_raw(buffer) as *mut u8;
    // SAFETY: both output pointers were checked to be non-null above.
    unsafe {
        *out_error = ptr;
        *out_error_len = len;
    }
}
468
469pub unsafe extern "C" fn python_invoke(
470    instance: *mut c_void,
471    handle: *mut c_void,
472    args_msgpack: *const u8,
473    args_len: usize,
474    out_ptr: *mut *mut u8,
475    out_len: *mut usize,
476) -> i32 {
477    if instance.is_null() || out_ptr.is_null() || out_len.is_null() {
478        return PluginError::InvalidArgument as i32;
479    }
480    let runtime = unsafe { &*(instance as *const PythonRuntime) };
481    let args_slice = if args_msgpack.is_null() || args_len == 0 {
482        &[]
483    } else {
484        unsafe { std::slice::from_raw_parts(args_msgpack, args_len) }
485    };
486
487    match runtime.invoke(handle, args_slice) {
488        Ok(mut bytes) => {
489            let len = bytes.len();
490            let ptr = bytes.as_mut_ptr();
491            std::mem::forget(bytes);
492            unsafe {
493                *out_ptr = ptr;
494                *out_len = len;
495            }
496            PluginError::Success as i32
497        }
498        Err(msg) => {
499            // Write error message to output buffer so the host can read it
500            let mut err_bytes = msg.into_bytes();
501            let len = err_bytes.len();
502            let ptr = err_bytes.as_mut_ptr();
503            std::mem::forget(err_bytes);
504            unsafe {
505                *out_ptr = ptr;
506                *out_len = len;
507            }
508            PluginError::NotImplemented as i32
509        }
510    }
511}
512
513pub unsafe extern "C" fn python_dispose_function(instance: *mut c_void, handle: *mut c_void) {
514    if instance.is_null() {
515        return;
516    }
517    let runtime = unsafe { &mut *(instance as *mut PythonRuntime) };
518    runtime.dispose_function(handle);
519}
520
/// C ABI entry point: return the NUL-terminated language identifier.
///
/// The returned pointer refers to static data owned by the extension; the
/// caller must not free it. `_instance` is not used.
pub unsafe extern "C" fn python_language_id(_instance: *mut c_void) -> *const std::ffi::c_char {
    // "python\0" -- static, owned by the extension.
    static LANGUAGE_ID: &std::ffi::CStr = c"python";
    LANGUAGE_ID.as_ptr()
}
525
526pub unsafe extern "C" fn python_get_lsp_config(
527    _instance: *mut c_void,
528    out_ptr: *mut *mut u8,
529    out_len: *mut usize,
530) -> i32 {
531    if out_ptr.is_null() || out_len.is_null() {
532        return PluginError::InvalidArgument as i32;
533    }
534    let config = PythonRuntime::lsp_config();
535    match rmp_serde::to_vec(&config) {
536        Ok(mut bytes) => {
537            let len = bytes.len();
538            let ptr = bytes.as_mut_ptr();
539            std::mem::forget(bytes);
540            unsafe {
541                *out_ptr = ptr;
542                *out_len = len;
543            }
544            PluginError::Success as i32
545        }
546        Err(_) => PluginError::InternalError as i32,
547    }
548}
549
/// C ABI entry point: release a byte buffer previously handed to the host.
///
/// The buffer is reclaimed as a `Vec<u8>` with both length and capacity equal
/// to `len`. Null pointers and zero-length buffers are ignored.
///
/// # Safety
///
/// `ptr` must be null or the start of a buffer allocated by this extension
/// with a capacity of exactly `len` bytes, and must not be used afterwards.
pub unsafe extern "C" fn python_free_buffer(ptr: *mut u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: per the contract above, (ptr, len) describes an allocation of
    // exactly `len` bytes created by this extension.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
555
556pub unsafe extern "C" fn python_drop(instance: *mut c_void) {
557    if !instance.is_null() {
558        let _ = unsafe { Box::from_raw(instance as *mut PythonRuntime) };
559    }
560}
561
562// ============================================================================
563// Helpers
564// ============================================================================
565
566/// Re-open libpython with RTLD_GLOBAL so its symbols are visible to C
567/// extensions (numpy, pandas, etc.) loaded later via Python's own dlopen.
568///
569/// We try common sonames in order. RTLD_NOLOAD ensures we only promote
570/// the copy already in memory — no new loading occurs.
571#[cfg(unix)]
572fn promote_libpython_symbols() {
573    const SONAMES: &[&[u8]] = &[
574        b"libpython3.13.so.1.0\0",
575        b"libpython3.13.so\0",
576        b"libpython3.12.so.1.0\0",
577        b"libpython3.12.so\0",
578        b"libpython3.11.so.1.0\0",
579        b"libpython3.11.so\0",
580        b"libpython3.so\0",
581    ];
582    for soname in SONAMES {
583        let handle = unsafe {
584            libc::dlopen(
585                soname.as_ptr() as *const std::ffi::c_char,
586                libc::RTLD_NOLOAD | libc::RTLD_NOW | libc::RTLD_GLOBAL,
587            )
588        };
589        if !handle.is_null() {
590            unsafe { libc::dlclose(handle) };
591            return;
592        }
593    }
594    // If none matched, fall through silently — basic Python works fine,
595    // only C extensions that reference libpython symbols will fail.
596}
597
/// View `len` bytes starting at `ptr` as a UTF-8 string slice.
///
/// Returns `None` when `ptr` is null, when `len` is zero, or when the bytes
/// are not valid UTF-8. The caller chooses the lifetime `'a` and must ensure
/// the memory stays valid and unmodified for that long.
fn str_from_raw<'a>(ptr: *const u8, len: usize) -> Option<&'a str> {
    if !ptr.is_null() && len != 0 {
        // SAFETY: non-null; the caller guarantees `len` readable bytes.
        let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
        match std::str::from_utf8(bytes) {
            Ok(text) => Some(text),
            Err(_) => None,
        }
    } else {
        None
    }
}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// The static LSP configuration advertises pyright over stdio for `.py`.
    #[test]
    fn lsp_config_exposes_pyright_defaults() {
        let lsp = PythonRuntime::lsp_config();
        let expected_command = vec!["pyright-langserver".to_string(), "--stdio".to_string()];

        assert_eq!(lsp.language_id, "python");
        assert_eq!(lsp.server_command, expected_command);
        assert_eq!(lsp.file_extension, ".py");
        assert!(lsp.extra_paths.is_empty());
    }

    /// The C ABI wrapper returns a buffer that decodes back into the config.
    #[test]
    fn python_get_lsp_config_returns_valid_msgpack_payload() {
        let mut payload_ptr: *mut u8 = std::ptr::null_mut();
        let mut payload_len: usize = 0;

        let status = unsafe {
            python_get_lsp_config(std::ptr::null_mut(), &mut payload_ptr, &mut payload_len)
        };
        assert_eq!(status, PluginError::Success as i32);
        assert!(!payload_ptr.is_null());
        assert!(payload_len > 0);

        let payload = unsafe { std::slice::from_raw_parts(payload_ptr, payload_len) };
        let roundtripped: LanguageRuntimeLspConfig =
            rmp_serde::from_slice(payload).expect("payload should decode");
        assert_eq!(roundtripped.language_id, "python");
        assert_eq!(roundtripped.file_extension, ".py");

        // Hand the buffer back so the test does not leak it.
        unsafe { python_free_buffer(payload_ptr, payload_len) };
    }
}
641}