1use pyo3::exceptions::{PyRuntimeError, PyValueError};
35use pyo3::prelude::*;
36use scirs2_numpy::{PyReadonlyArray1, PyReadonlyArray2};
37
38use scirs2_symbolic::eml::eval::{eval_real as rust_eval_real, EvalCtx};
39use scirs2_symbolic::eml::{
40 grad as rust_grad, lower as rust_lower, simplify_op as rust_simplify_op,
41 Canonical as RustCanonical, EmlTree as RustEmlTree, LoweredOp as RustLoweredOp,
42};
43use scirs2_symbolic::regression::{discover as rust_discover, SrConfig as RustSrConfig};
44
45#[pyclass(name = "EmlTree", module = "scirs2.symbolic", skip_from_py_object)]
51#[derive(Clone)]
52pub struct PyEmlTree {
53 inner: RustEmlTree,
54}
55
56#[pymethods]
57impl PyEmlTree {
58 #[staticmethod]
60 fn one() -> Self {
61 Self {
62 inner: RustEmlTree::one(),
63 }
64 }
65
66 #[staticmethod]
68 fn var(idx: usize) -> Self {
69 Self {
70 inner: RustEmlTree::var(idx),
71 }
72 }
73
74 #[staticmethod]
76 fn eml(left: &Self, right: &Self) -> Self {
77 Self {
78 inner: RustEmlTree::eml(&left.inner, &right.inner),
79 }
80 }
81
82 fn depth(&self) -> usize {
84 self.inner.depth()
85 }
86
87 fn size(&self) -> usize {
89 self.inner.size()
90 }
91
92 fn num_vars(&self) -> usize {
94 self.inner.num_vars()
95 }
96
97 fn structural_hash(&self) -> (u64, u64) {
100 let h = self.inner.structural_hash();
101 ((h >> 64) as u64, (h & 0xFFFF_FFFF_FFFF_FFFF) as u64)
102 }
103
104 fn __repr__(&self) -> String {
105 format!(
106 "EmlTree(depth={}, size={}, num_vars={})",
107 self.depth(),
108 self.size(),
109 self.num_vars()
110 )
111 }
112}
113
114#[pyclass(name = "Canonical", module = "scirs2.symbolic")]
123pub struct PyCanonical;
124
125#[pymethods]
126impl PyCanonical {
127 #[staticmethod]
130 fn exp(x: &PyEmlTree) -> PyEmlTree {
131 PyEmlTree {
132 inner: RustCanonical::exp(&x.inner),
133 }
134 }
135 #[staticmethod]
137 fn ln(x: &PyEmlTree) -> PyEmlTree {
138 PyEmlTree {
139 inner: RustCanonical::ln(&x.inner),
140 }
141 }
142 #[staticmethod]
144 fn euler() -> PyEmlTree {
145 PyEmlTree {
146 inner: RustCanonical::euler(),
147 }
148 }
149 #[staticmethod]
151 fn pi() -> PyEmlTree {
152 PyEmlTree {
153 inner: RustCanonical::pi(),
154 }
155 }
156 #[staticmethod]
158 fn neg(x: &PyEmlTree) -> PyEmlTree {
159 PyEmlTree {
160 inner: RustCanonical::neg(&x.inner),
161 }
162 }
163
164 #[staticmethod]
167 fn add(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
168 PyEmlTree {
169 inner: RustCanonical::add(&a.inner, &b.inner),
170 }
171 }
172 #[staticmethod]
174 fn sub(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
175 PyEmlTree {
176 inner: RustCanonical::sub(&a.inner, &b.inner),
177 }
178 }
179 #[staticmethod]
181 fn mul(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
182 PyEmlTree {
183 inner: RustCanonical::mul(&a.inner, &b.inner),
184 }
185 }
186 #[staticmethod]
188 fn div(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
189 PyEmlTree {
190 inner: RustCanonical::div(&a.inner, &b.inner),
191 }
192 }
193 #[staticmethod]
195 fn pow(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
196 PyEmlTree {
197 inner: RustCanonical::pow(&a.inner, &b.inner),
198 }
199 }
200
201 #[staticmethod]
204 fn sin(x: &PyEmlTree) -> PyEmlTree {
205 PyEmlTree {
206 inner: RustCanonical::sin(&x.inner),
207 }
208 }
209 #[staticmethod]
211 fn cos(x: &PyEmlTree) -> PyEmlTree {
212 PyEmlTree {
213 inner: RustCanonical::cos(&x.inner),
214 }
215 }
216 #[staticmethod]
218 fn tan(x: &PyEmlTree) -> PyEmlTree {
219 PyEmlTree {
220 inner: RustCanonical::tan(&x.inner),
221 }
222 }
223
224 #[staticmethod]
227 fn arcsin(x: &PyEmlTree) -> PyEmlTree {
228 PyEmlTree {
229 inner: RustCanonical::arcsin(&x.inner),
230 }
231 }
232 #[staticmethod]
234 fn arccos(x: &PyEmlTree) -> PyEmlTree {
235 PyEmlTree {
236 inner: RustCanonical::arccos(&x.inner),
237 }
238 }
239 #[staticmethod]
241 fn arctan(x: &PyEmlTree) -> PyEmlTree {
242 PyEmlTree {
243 inner: RustCanonical::arctan(&x.inner),
244 }
245 }
246
247 #[staticmethod]
250 fn sinh(x: &PyEmlTree) -> PyEmlTree {
251 PyEmlTree {
252 inner: RustCanonical::sinh(&x.inner),
253 }
254 }
255 #[staticmethod]
257 fn cosh(x: &PyEmlTree) -> PyEmlTree {
258 PyEmlTree {
259 inner: RustCanonical::cosh(&x.inner),
260 }
261 }
262 #[staticmethod]
264 fn tanh(x: &PyEmlTree) -> PyEmlTree {
265 PyEmlTree {
266 inner: RustCanonical::tanh(&x.inner),
267 }
268 }
269
270 #[staticmethod]
273 fn arcsinh(x: &PyEmlTree) -> PyEmlTree {
274 PyEmlTree {
275 inner: RustCanonical::arcsinh(&x.inner),
276 }
277 }
278 #[staticmethod]
280 fn arccosh(x: &PyEmlTree) -> PyEmlTree {
281 PyEmlTree {
282 inner: RustCanonical::arccosh(&x.inner),
283 }
284 }
285 #[staticmethod]
287 fn arctanh(x: &PyEmlTree) -> PyEmlTree {
288 PyEmlTree {
289 inner: RustCanonical::arctanh(&x.inner),
290 }
291 }
292
293 #[staticmethod]
296 fn sqrt(x: &PyEmlTree) -> PyEmlTree {
297 PyEmlTree {
298 inner: RustCanonical::sqrt(&x.inner),
299 }
300 }
301 #[staticmethod]
303 fn abs(x: &PyEmlTree) -> PyEmlTree {
304 PyEmlTree {
305 inner: RustCanonical::abs(&x.inner),
306 }
307 }
308 #[staticmethod]
310 fn square(x: &PyEmlTree) -> PyEmlTree {
311 PyEmlTree {
312 inner: RustCanonical::square(&x.inner),
313 }
314 }
315 #[staticmethod]
317 fn reciprocal(x: &PyEmlTree) -> PyEmlTree {
318 PyEmlTree {
319 inner: RustCanonical::reciprocal(&x.inner),
320 }
321 }
322
323 #[staticmethod]
327 fn nat(n: u64) -> PyResult<PyEmlTree> {
328 RustCanonical::nat(n)
329 .map(|t| PyEmlTree { inner: t })
330 .map_err(|e| PyValueError::new_err(e.to_string()))
331 }
332
333 #[staticmethod]
335 fn zero() -> PyEmlTree {
336 PyEmlTree {
337 inner: RustCanonical::zero(),
338 }
339 }
340
341 #[staticmethod]
343 fn neg_one() -> PyEmlTree {
344 PyEmlTree {
345 inner: RustCanonical::neg_one(),
346 }
347 }
348
349 #[staticmethod]
351 fn imag_unit() -> PyEmlTree {
352 PyEmlTree {
353 inner: RustCanonical::imag_unit(),
354 }
355 }
356}
357
358#[pyclass(name = "LoweredOp", module = "scirs2.symbolic", skip_from_py_object)]
365#[derive(Clone)]
366pub struct PyLoweredOp {
367 inner: RustLoweredOp,
368}
369
370#[pymethods]
371impl PyLoweredOp {
372 fn count_vars(&self) -> usize {
374 self.inner.count_vars()
375 }
376
377 fn structural_hash(&self) -> (u64, u64) {
379 let h = self.inner.structural_hash();
380 ((h >> 64) as u64, (h & 0xFFFF_FFFF_FFFF_FFFF) as u64)
381 }
382
383 fn __repr__(&self) -> String {
384 format!("LoweredOp(count_vars={})", self.count_vars())
385 }
386}
387
388#[pyfunction]
394fn lower(tree: &PyEmlTree) -> PyLoweredOp {
395 PyLoweredOp {
396 inner: rust_lower(&tree.inner),
397 }
398}
399
400#[pyfunction]
402fn simplify(op: &PyLoweredOp) -> PyLoweredOp {
403 PyLoweredOp {
404 inner: rust_simplify_op(&op.inner),
405 }
406}
407
408#[pyfunction]
410fn grad(op: &PyLoweredOp, wrt: usize) -> PyLoweredOp {
411 PyLoweredOp {
412 inner: rust_grad(&op.inner, wrt),
413 }
414}
415
416#[pyfunction]
421fn eval_real(op: &PyLoweredOp, vars: Vec<f64>) -> PyResult<f64> {
422 let ctx = EvalCtx::new(&vars);
423 rust_eval_real(&op.inner, &ctx).map_err(|e| PyRuntimeError::new_err(e.to_string()))
424}
425
426#[pyfunction]
437#[pyo3(signature = (
438 features,
439 targets,
440 max_iter = 50,
441 top_n = 3,
442 beam_width = 32,
443 max_depth = 6,
444 max_nodes = 20,
445))]
446#[allow(clippy::too_many_arguments)]
447fn discover(
448 py: Python<'_>,
449 features: PyReadonlyArray2<f64>,
450 targets: PyReadonlyArray1<f64>,
451 max_iter: usize,
452 top_n: usize,
453 beam_width: usize,
454 max_depth: usize,
455 max_nodes: usize,
456) -> PyResult<Vec<PyDiscoveredFormula>> {
457 let features_arr = features.as_array();
458 let targets_arr = targets.as_array();
459
460 let config = RustSrConfig::default()
461 .with_max_iter(max_iter)
462 .with_top_n(top_n)
463 .with_beam_width(beam_width)
464 .with_max_depth(max_depth)
465 .with_max_nodes(max_nodes);
466
467 let results = py.detach(|| rust_discover(features_arr, targets_arr, &config));
469
470 Ok(results
471 .into_iter()
472 .map(|f| PyDiscoveredFormula {
473 op: PyLoweredOp { inner: f.op },
474 mse: f.fitness.mse,
475 r_squared: f.fitness.r_squared,
476 combined: f.fitness.combined,
477 node_count: f.node_count,
478 n_vars: f.n_vars,
479 })
480 .collect())
481}
482
483#[pyclass(
485 name = "DiscoveredFormula",
486 module = "scirs2.symbolic",
487 skip_from_py_object
488)]
489#[derive(Clone)]
490pub struct PyDiscoveredFormula {
491 #[pyo3(get)]
493 pub op: PyLoweredOp,
494 #[pyo3(get)]
496 pub mse: f64,
497 #[pyo3(get)]
499 pub r_squared: f64,
500 #[pyo3(get)]
502 pub combined: f64,
503 #[pyo3(get)]
505 pub node_count: usize,
506 #[pyo3(get)]
508 pub n_vars: usize,
509}
510
511#[pymethods]
512impl PyDiscoveredFormula {
513 fn __repr__(&self) -> String {
514 format!(
515 "DiscoveredFormula(mse={:.6}, r_squared={:.6}, n_nodes={}, n_vars={})",
516 self.mse, self.r_squared, self.node_count, self.n_vars
517 )
518 }
519}
520
521pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
527 let py = m.py();
528 let symbolic = PyModule::new(py, "symbolic")?;
529
530 symbolic.add_class::<PyEmlTree>()?;
531 symbolic.add_class::<PyCanonical>()?;
532 symbolic.add_class::<PyLoweredOp>()?;
533 symbolic.add_class::<PyDiscoveredFormula>()?;
534
535 symbolic.add_function(wrap_pyfunction!(lower, &symbolic)?)?;
536 symbolic.add_function(wrap_pyfunction!(simplify, &symbolic)?)?;
537 symbolic.add_function(wrap_pyfunction!(grad, &symbolic)?)?;
538 symbolic.add_function(wrap_pyfunction!(eval_real, &symbolic)?)?;
539 symbolic.add_function(wrap_pyfunction!(discover, &symbolic)?)?;
540
541 symbolic.add(
542 "__doc__",
543 "Symbolic mathematics — EML substrate, evaluation, gradient, and \
544 beam-search symbolic regression.\n\nClasses:\n - EmlTree: uniform \
545 binary EML tree (constant 1 + var leaves + binary eml nodes).\n - \
546 Canonical: namespace of elementary-function constructors.\n - \
547 LoweredOp: flat operator IR produced by lower(tree).\n - \
548 DiscoveredFormula: result of discover().\n\nFunctions:\n - \
549 lower(tree) -> LoweredOp\n - simplify(op) -> LoweredOp\n - grad(op, wrt) \
550 -> LoweredOp\n - eval_real(op, vars) -> float\n - discover(features, \
551 targets, ...) -> list[DiscoveredFormula]",
552 )?;
553
554 m.add_submodule(&symbolic)?;
555 Ok(())
556}