// rustpy_ml/macro_embed.rs

1use pyo3::prelude::*;
2use pyo3::types::PyModule;
3use std::collections::HashMap;
4
5/// Execute Python code as a string and return the result
6/// This is the runtime function called by the python! macro
7pub fn run_python_code<T>(code: &str, locals: Option<HashMap<String, Py<PyAny>>>) -> crate::Result<T>
8where
9    T: for<'py> FromPyObject<'py>,
10{
11    crate::runtime::init()?;
12    
13    Python::with_gil(|py| {
14        // Always create a locals dict for consistent behavior
15        let locals_dict = pyo3::types::PyDict::new_bound(py);
16        
17        // Add globals first
18        let globals = crate::runtime::GLOBALS.get()
19            .ok_or_else(|| crate::Error::Interp("Runtime not initialized".to_string()))?;
20        let globals_map = globals.lock().unwrap();
21        for (key, value) in globals_map.iter() {
22            locals_dict.set_item(key, value)?;
23        }
24        drop(globals_map); // Release the lock
25        
26        // Track initial keys
27        let mut initial_keys = std::collections::HashSet::new();
28        for key in locals_dict.keys() {
29            if let Ok(k) = key.extract::<String>() {
30                initial_keys.insert(k);
31            }
32        }
33        
34        if let Some(locals_map) = &locals {
35            for (key, value) in locals_map {
36                initial_keys.insert(key.clone());
37                locals_dict.set_item(key, value)?;
38            }
39        }
40        
41        // Try to evaluate as expression first (wrapping in __result__)
42        let wrapped_code = format!("__result__ = {}", code.trim());
43        
44        let result = py.run_bound(&wrapped_code, None, Some(&locals_dict));
45            
46        match result {
47            Ok(_) => {
48                // Successfully evaluated as expression, get the result
49                let result_obj = locals_dict.get_item("__result__")?.ok_or_else(|| {
50                    pyo3::exceptions::PyKeyError::new_err("__result__ not found")
51                })?;
52                result_obj.extract::<T>().map_err(Into::into)
53            },
54            Err(_) => {
55                // Not an expression - try to execute as statements
56                py.run_bound(code, None, Some(&locals_dict))?;
57                
58                // After execution, look for 'result' variable first (convention)
59                if let Some(result) = locals_dict.get_item("result")? {
60                    // Update globals before returning
61                    let mut globals_map = globals.lock().unwrap();
62                    for (key, value) in locals_dict.iter() {
63                        if let Ok(k) = key.extract::<String>() {
64                            if !k.starts_with("__") && !k.ends_with("__") && !initial_keys.contains(&k) {
65                                globals_map.insert(k, value.into());
66                            }
67                        }
68                    }
69                    return result.extract::<T>().map_err(Into::into);
70                }
71                
72                // Look for the last non-dunder, non-module variable added during execution
73                let items: Vec<(String, pyo3::Bound<pyo3::PyAny>)> = locals_dict.items().extract()?;
74                for (key, value) in items.iter().rev() {
75                    // Skip builtins, private, and initial keys
76                    if key.starts_with("__") || key.starts_with("_") || initial_keys.contains(key) {
77                        continue;
78                    }
79                    
80                    // Skip modules (they have __name__ attribute)
81                    if value.hasattr("__name__").unwrap_or(false) && 
82                       value.hasattr("__package__").unwrap_or(false) {
83                        continue;
84                    }
85                    
86                    // Update globals before returning
87                    let mut globals_map = globals.lock().unwrap();
88                    for (k, v) in locals_dict.iter() {
89                        if let Ok(k_str) = k.extract::<String>() {
90                            if !k_str.starts_with("__") && !k_str.ends_with("__") && !initial_keys.contains(&k_str) {
91                                globals_map.insert(k_str, v.into());
92                            }
93                        }
94                    }
95                    return value.extract::<T>().map_err(Into::into);
96                }
97                
98                // Return None if no result found
99                py.None().extract::<T>(py).map_err(Into::into)
100            }
101        }
102    })
103}
104
105/// Execute Python code without returning a value
106pub fn exec_python_code(code: &str, locals: Option<HashMap<String, Py<PyAny>>>) -> crate::Result<()> {
107    crate::runtime::init()?;
108    
109    Python::with_gil(|py| {
110        // Create locals dictionary
111        let locals_dict = pyo3::types::PyDict::new_bound(py);
112        
113        // Add globals first
114        let globals = crate::runtime::GLOBALS.get()
115            .ok_or_else(|| crate::Error::Interp("Runtime not initialized".to_string()))?;
116        let globals_map = globals.lock().unwrap();
117        for (key, value) in globals_map.iter() {
118            locals_dict.set_item(key, value)?;
119        }
120        drop(globals_map); // Release the lock
121        
122        // Add provided locals to the dictionary (can override globals)
123        if let Some(locals_map) = locals {
124            for (key, value) in locals_map {
125                locals_dict.set_item(key, value)?;
126            }
127        }
128        
129        // Execute code with locals
130        py.run_bound(code, None, Some(&locals_dict))?;
131        
132        // Update globals with any new variables defined in locals
133        let mut globals_map = globals.lock().unwrap();
134        for (key, value) in locals_dict.iter() {
135            // Skip dunders and built-ins
136            if let Ok(key_str) = key.extract::<String>() {
137                if !key_str.starts_with("__") && !key_str.ends_with("__") {
138                    globals_map.insert(key_str, value.into());
139                }
140            }
141        }
142        
143        Ok(())
144    })
145}
146
/// Run a Python snippet and return its value.
///
/// Expands to a call to `macro_embed::run_python_code`. Without a type, the
/// result is `Result<pyo3::PyObject>`; with `-> Type` it is extracted as that
/// type. Locals must be a `HashMap<String, pyo3::PyObject>`, e.g. built with
/// `py_locals!`. A trailing comma is accepted.
///
/// The expansion names `pyo3::PyObject` in the untyped forms, so calling
/// crates need `pyo3` as a direct dependency.
///
/// ```ignore
/// let obj = python!("print('Hello from Python')")?;
/// let two = python!(-> i32, "1 + 1")?;
/// let obj = python!(py_locals! { "x" => 5 }, "x * 2")?;
/// let ten = python!(py_locals! { "x" => 5 }, -> i32, "x * 2")?;
/// ```
#[macro_export]
macro_rules! python {
    // python!("code") — value as a PyObject.
    ($code:expr $(,)?) => {{
        $crate::macro_embed::run_python_code::<pyo3::PyObject>($code, None)
    }};

    // python!(-> Type, "code") — value extracted as `Type`.
    (-> $ty:ty, $code:expr $(,)?) => {{
        $crate::macro_embed::run_python_code::<$ty>($code, None)
    }};

    // python!(locals, "code") — locals is a HashMap<String, PyObject>.
    ($locals:expr, $code:expr $(,)?) => {{
        $crate::macro_embed::run_python_code::<pyo3::PyObject>($code, Some($locals))
    }};

    // python!(locals, -> Type, "code")
    ($locals:expr, -> $ty:ty, $code:expr $(,)?) => {{
        $crate::macro_embed::run_python_code::<$ty>($code, Some($locals))
    }};
}
171
/// Execute Python statements for their side effects only.
///
/// Expands to a call to `macro_embed::exec_python_code` and yields its
/// `Result<()>`. Optional locals must be a `HashMap<String, pyo3::PyObject>`,
/// e.g. built with `py_locals!`. A trailing comma is accepted.
///
/// ```ignore
/// py_exec!("import numpy as np\nprint(np.__version__)")?;
/// py_exec!(py_locals! { "x" => 5 }, "print(x)")?;
/// ```
#[macro_export]
macro_rules! py_exec {
    ($code:expr $(,)?) => {{
        $crate::macro_embed::exec_python_code($code, None)
    }};

    ($locals:expr, $code:expr $(,)?) => {{
        $crate::macro_embed::exec_python_code($code, Some($locals))
    }};
}
184
/// Evaluate a Python expression via `runtime::eval`.
///
/// The macro is a thin forwarder: its single argument is passed unchanged to
/// `runtime::eval`, and that function's result is the macro's value.
///
/// Usage: `let result: i32 = py_eval!("1 + 1");`
#[macro_export]
macro_rules! py_eval {
    ($src:expr) => {{
        $crate::runtime::eval($src)
    }};
}
193
/// Build the locals map accepted by `python!` and `py_exec!`.
///
/// Each key is converted with `ToString`. Each value is converted with
/// `pyo3::IntoPy` while the GIL is held.
///
/// The map's type is spelled out as `HashMap<String, pyo3::PyObject>`, so the
/// macro also compiles where nothing else pins the type down (e.g.
/// `let locals = py_locals! {};`). A trailing comma is accepted.
///
/// Usage: `py_locals! { "x" => 5, "y" => 10 }`
#[macro_export]
macro_rules! py_locals {
    ($($key:expr => $value:expr),* $(,)?) => {{
        let mut map: std::collections::HashMap<std::string::String, pyo3::PyObject> =
            std::collections::HashMap::new();
        pyo3::Python::with_gil(|py| {
            $(
                map.insert($key.to_string(), pyo3::IntoPy::into_py($value, py));
            )*
        });
        map
    }};
}