Skip to main content

scirs2/
symbolic.rs

1//! Python bindings for scirs2-symbolic.
2//!
3//! Exposes the EML substrate (`EmlTree`, `Canonical`, `LoweredOp`),
4//! evaluation (`eval_real`), gradient (`grad`), and the symbolic-regression
5//! API (`discover`) under the Python sub-namespace `scirs2.symbolic`.
6//!
7//! # Example (Python)
8//!
9//! ```python
10//! import scirs2 as s2
11//! import numpy as np
12//!
13//! # Build an EML tree: sin(x²)
14//! x = s2.symbolic.EmlTree.var(0)
15//! formula = s2.symbolic.Canonical.sin(s2.symbolic.Canonical.mul(x, x))
16//! lowered = s2.symbolic.lower(formula)
17//!
18//! # Evaluate at x = 0.5
19//! result = s2.symbolic.eval_real(lowered, [0.5])
20//! print(result)  # ~0.247
21//!
22//! # Symbolic gradient with respect to variable 0
23//! grad_op = s2.symbolic.grad(lowered, 0)
24//!
25//! # Symbolic regression
26//! features = np.array([[1.0], [2.0], [3.0]])
27//! targets = np.array([1.0, 4.0, 9.0])
28//! results = s2.symbolic.discover(features, targets)
29//! ```
30//!
31//! Note: `Canonical::sin` produces a 543-node-deep canonical EML tree;
32//! evaluation is iterative (no stack blowup).
33
34use pyo3::exceptions::{PyRuntimeError, PyValueError};
35use pyo3::prelude::*;
36use scirs2_numpy::{PyReadonlyArray1, PyReadonlyArray2};
37
38use scirs2_symbolic::eml::eval::{eval_real as rust_eval_real, EvalCtx};
39use scirs2_symbolic::eml::{
40    grad as rust_grad, lower as rust_lower, simplify_op as rust_simplify_op,
41    Canonical as RustCanonical, EmlTree as RustEmlTree, LoweredOp as RustLoweredOp,
42};
43use scirs2_symbolic::regression::{discover as rust_discover, SrConfig as RustSrConfig};
44
45// ============================================================================
46// EmlTree wrapper
47// ============================================================================
48
49/// Python wrapper for [`scirs2_symbolic::eml::EmlTree`].
50#[pyclass(name = "EmlTree", module = "scirs2.symbolic", skip_from_py_object)]
51#[derive(Clone)]
52pub struct PyEmlTree {
53    inner: RustEmlTree,
54}
55
56#[pymethods]
57impl PyEmlTree {
58    /// Construct the constant `1` — the only EML leaf.
59    #[staticmethod]
60    fn one() -> Self {
61        Self {
62            inner: RustEmlTree::one(),
63        }
64    }
65
66    /// Construct a variable at index `idx`.
67    #[staticmethod]
68    fn var(idx: usize) -> Self {
69        Self {
70            inner: RustEmlTree::var(idx),
71        }
72    }
73
74    /// Construct `eml(left, right) = exp(left) - ln(right)`.
75    #[staticmethod]
76    fn eml(left: &Self, right: &Self) -> Self {
77        Self {
78            inner: RustEmlTree::eml(&left.inner, &right.inner),
79        }
80    }
81
82    /// Tree depth.
83    fn depth(&self) -> usize {
84        self.inner.depth()
85    }
86
87    /// Total node count.
88    fn size(&self) -> usize {
89        self.inner.size()
90    }
91
92    /// Number of distinct variables (max var index + 1, or 0 if none).
93    fn num_vars(&self) -> usize {
94        self.inner.num_vars()
95    }
96
97    /// Structural hash returned as `(high_u64, low_u64)`, since Python lacks
98    /// a native `u128` type.
99    fn structural_hash(&self) -> (u64, u64) {
100        let h = self.inner.structural_hash();
101        ((h >> 64) as u64, (h & 0xFFFF_FFFF_FFFF_FFFF) as u64)
102    }
103
104    fn __repr__(&self) -> String {
105        format!(
106            "EmlTree(depth={}, size={}, num_vars={})",
107            self.depth(),
108            self.size(),
109            self.num_vars()
110        )
111    }
112}
113
114// ============================================================================
115// Canonical namespace
116// ============================================================================
117
118/// Namespace for canonical EML constructors.
119///
120/// Mirrors `scirs2_symbolic::eml::Canonical`. Every method returns a
121/// canonical [`PyEmlTree`] for the named elementary function.
122#[pyclass(name = "Canonical", module = "scirs2.symbolic")]
123pub struct PyCanonical;
124
125#[pymethods]
126impl PyCanonical {
127    // ----- Basic operations -----
128    /// `exp(x)`.
129    #[staticmethod]
130    fn exp(x: &PyEmlTree) -> PyEmlTree {
131        PyEmlTree {
132            inner: RustCanonical::exp(&x.inner),
133        }
134    }
135    /// `ln(x)`.
136    #[staticmethod]
137    fn ln(x: &PyEmlTree) -> PyEmlTree {
138        PyEmlTree {
139            inner: RustCanonical::ln(&x.inner),
140        }
141    }
142    /// Euler's number `e`.
143    #[staticmethod]
144    fn euler() -> PyEmlTree {
145        PyEmlTree {
146            inner: RustCanonical::euler(),
147        }
148    }
149    /// `pi` (encoded such that complex evaluation yields `iπ`).
150    #[staticmethod]
151    fn pi() -> PyEmlTree {
152        PyEmlTree {
153            inner: RustCanonical::pi(),
154        }
155    }
156    /// Negation `-x`.
157    #[staticmethod]
158    fn neg(x: &PyEmlTree) -> PyEmlTree {
159        PyEmlTree {
160            inner: RustCanonical::neg(&x.inner),
161        }
162    }
163
164    // ----- Arithmetic -----
165    /// `a + b`.
166    #[staticmethod]
167    fn add(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
168        PyEmlTree {
169            inner: RustCanonical::add(&a.inner, &b.inner),
170        }
171    }
172    /// `a - b`.
173    #[staticmethod]
174    fn sub(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
175        PyEmlTree {
176            inner: RustCanonical::sub(&a.inner, &b.inner),
177        }
178    }
179    /// `a * b`.
180    #[staticmethod]
181    fn mul(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
182        PyEmlTree {
183            inner: RustCanonical::mul(&a.inner, &b.inner),
184        }
185    }
186    /// `a / b`.
187    #[staticmethod]
188    fn div(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
189        PyEmlTree {
190            inner: RustCanonical::div(&a.inner, &b.inner),
191        }
192    }
193    /// `a ^ b` (power).
194    #[staticmethod]
195    fn pow(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
196        PyEmlTree {
197            inner: RustCanonical::pow(&a.inner, &b.inner),
198        }
199    }
200
201    // ----- Trig -----
202    /// `sin(x)`.
203    #[staticmethod]
204    fn sin(x: &PyEmlTree) -> PyEmlTree {
205        PyEmlTree {
206            inner: RustCanonical::sin(&x.inner),
207        }
208    }
209    /// `cos(x)`.
210    #[staticmethod]
211    fn cos(x: &PyEmlTree) -> PyEmlTree {
212        PyEmlTree {
213            inner: RustCanonical::cos(&x.inner),
214        }
215    }
216    /// `tan(x)`.
217    #[staticmethod]
218    fn tan(x: &PyEmlTree) -> PyEmlTree {
219        PyEmlTree {
220            inner: RustCanonical::tan(&x.inner),
221        }
222    }
223
224    // ----- Inverse trig -----
225    /// `arcsin(x)`.
226    #[staticmethod]
227    fn arcsin(x: &PyEmlTree) -> PyEmlTree {
228        PyEmlTree {
229            inner: RustCanonical::arcsin(&x.inner),
230        }
231    }
232    /// `arccos(x)`.
233    #[staticmethod]
234    fn arccos(x: &PyEmlTree) -> PyEmlTree {
235        PyEmlTree {
236            inner: RustCanonical::arccos(&x.inner),
237        }
238    }
239    /// `arctan(x)`.
240    #[staticmethod]
241    fn arctan(x: &PyEmlTree) -> PyEmlTree {
242        PyEmlTree {
243            inner: RustCanonical::arctan(&x.inner),
244        }
245    }
246
247    // ----- Hyperbolic -----
248    /// `sinh(x)`.
249    #[staticmethod]
250    fn sinh(x: &PyEmlTree) -> PyEmlTree {
251        PyEmlTree {
252            inner: RustCanonical::sinh(&x.inner),
253        }
254    }
255    /// `cosh(x)`.
256    #[staticmethod]
257    fn cosh(x: &PyEmlTree) -> PyEmlTree {
258        PyEmlTree {
259            inner: RustCanonical::cosh(&x.inner),
260        }
261    }
262    /// `tanh(x)`.
263    #[staticmethod]
264    fn tanh(x: &PyEmlTree) -> PyEmlTree {
265        PyEmlTree {
266            inner: RustCanonical::tanh(&x.inner),
267        }
268    }
269
270    // ----- Inverse hyperbolic -----
271    /// `arcsinh(x)`.
272    #[staticmethod]
273    fn arcsinh(x: &PyEmlTree) -> PyEmlTree {
274        PyEmlTree {
275            inner: RustCanonical::arcsinh(&x.inner),
276        }
277    }
278    /// `arccosh(x)`.
279    #[staticmethod]
280    fn arccosh(x: &PyEmlTree) -> PyEmlTree {
281        PyEmlTree {
282            inner: RustCanonical::arccosh(&x.inner),
283        }
284    }
285    /// `arctanh(x)`.
286    #[staticmethod]
287    fn arctanh(x: &PyEmlTree) -> PyEmlTree {
288        PyEmlTree {
289            inner: RustCanonical::arctanh(&x.inner),
290        }
291    }
292
293    // ----- Powers, roots, abs -----
294    /// `sqrt(x)`.
295    #[staticmethod]
296    fn sqrt(x: &PyEmlTree) -> PyEmlTree {
297        PyEmlTree {
298            inner: RustCanonical::sqrt(&x.inner),
299        }
300    }
301    /// `|x|`.
302    #[staticmethod]
303    fn abs(x: &PyEmlTree) -> PyEmlTree {
304        PyEmlTree {
305            inner: RustCanonical::abs(&x.inner),
306        }
307    }
308    /// `x²`.
309    #[staticmethod]
310    fn square(x: &PyEmlTree) -> PyEmlTree {
311        PyEmlTree {
312            inner: RustCanonical::square(&x.inner),
313        }
314    }
315    /// `1 / x`.
316    #[staticmethod]
317    fn reciprocal(x: &PyEmlTree) -> PyEmlTree {
318        PyEmlTree {
319            inner: RustCanonical::reciprocal(&x.inner),
320        }
321    }
322
323    // ----- Constants -----
324    /// Natural number `n >= 1`. Raises `ValueError` on `n == 0`
325    /// (use `zero()` for the additive identity).
326    #[staticmethod]
327    fn nat(n: u64) -> PyResult<PyEmlTree> {
328        RustCanonical::nat(n)
329            .map(|t| PyEmlTree { inner: t })
330            .map_err(|e| PyValueError::new_err(e.to_string()))
331    }
332
333    /// Additive identity `0`.
334    #[staticmethod]
335    fn zero() -> PyEmlTree {
336        PyEmlTree {
337            inner: RustCanonical::zero(),
338        }
339    }
340
341    /// Negative one `-1`.
342    #[staticmethod]
343    fn neg_one() -> PyEmlTree {
344        PyEmlTree {
345            inner: RustCanonical::neg_one(),
346        }
347    }
348
349    /// Imaginary unit `i = exp(iπ/2)` (purely imaginary; `eval_real` errors).
350    #[staticmethod]
351    fn imag_unit() -> PyEmlTree {
352        PyEmlTree {
353            inner: RustCanonical::imag_unit(),
354        }
355    }
356}
357
358// ============================================================================
359// LoweredOp wrapper
360// ============================================================================
361
362/// Python wrapper for [`scirs2_symbolic::eml::LoweredOp`] — the flat
363/// operator IR produced by `lower`.
364#[pyclass(name = "LoweredOp", module = "scirs2.symbolic", skip_from_py_object)]
365#[derive(Clone)]
366pub struct PyLoweredOp {
367    inner: RustLoweredOp,
368}
369
370#[pymethods]
371impl PyLoweredOp {
372    /// Number of distinct variables (max var index + 1, or 0 if none).
373    fn count_vars(&self) -> usize {
374        self.inner.count_vars()
375    }
376
377    /// Structural hash returned as `(high_u64, low_u64)`.
378    fn structural_hash(&self) -> (u64, u64) {
379        let h = self.inner.structural_hash();
380        ((h >> 64) as u64, (h & 0xFFFF_FFFF_FFFF_FFFF) as u64)
381    }
382
383    fn __repr__(&self) -> String {
384        format!("LoweredOp(count_vars={})", self.count_vars())
385    }
386}
387
388// ============================================================================
389// Top-level functions
390// ============================================================================
391
392/// Lower an [`PyEmlTree`] to a [`PyLoweredOp`].
393#[pyfunction]
394fn lower(tree: &PyEmlTree) -> PyLoweredOp {
395    PyLoweredOp {
396        inner: rust_lower(&tree.inner),
397    }
398}
399
400/// Algebraically simplify a [`PyLoweredOp`].
401#[pyfunction]
402fn simplify(op: &PyLoweredOp) -> PyLoweredOp {
403    PyLoweredOp {
404        inner: rust_simplify_op(&op.inner),
405    }
406}
407
408/// Symbolic gradient with respect to variable `wrt`.
409#[pyfunction]
410fn grad(op: &PyLoweredOp, wrt: usize) -> PyLoweredOp {
411    PyLoweredOp {
412        inner: rust_grad(&op.inner, wrt),
413    }
414}
415
416/// Evaluate a [`PyLoweredOp`] at the given real variable values.
417///
418/// `vars[i]` binds variable index `i`. Raises `RuntimeError` on
419/// numerical-domain failures (e.g. `ln(0)`).
420#[pyfunction]
421fn eval_real(op: &PyLoweredOp, vars: Vec<f64>) -> PyResult<f64> {
422    let ctx = EvalCtx::new(&vars);
423    rust_eval_real(&op.inner, &ctx).map_err(|e| PyRuntimeError::new_err(e.to_string()))
424}
425
426// ============================================================================
427// Symbolic regression
428// ============================================================================
429
430/// Beam-search symbolic regression — discovers formulas approximating
431/// `targets ≈ f(features)`.
432///
433/// `features` has shape `(n_samples, n_features)`; `targets` has shape
434/// `(n_samples,)`. Returns up to `top_n` formulas, ranked by combined
435/// fitness (lower is better).
436#[pyfunction]
437#[pyo3(signature = (
438    features,
439    targets,
440    max_iter = 50,
441    top_n = 3,
442    beam_width = 32,
443    max_depth = 6,
444    max_nodes = 20,
445))]
446#[allow(clippy::too_many_arguments)]
447fn discover(
448    py: Python<'_>,
449    features: PyReadonlyArray2<f64>,
450    targets: PyReadonlyArray1<f64>,
451    max_iter: usize,
452    top_n: usize,
453    beam_width: usize,
454    max_depth: usize,
455    max_nodes: usize,
456) -> PyResult<Vec<PyDiscoveredFormula>> {
457    let features_arr = features.as_array();
458    let targets_arr = targets.as_array();
459
460    let config = RustSrConfig::default()
461        .with_max_iter(max_iter)
462        .with_top_n(top_n)
463        .with_beam_width(beam_width)
464        .with_max_depth(max_depth)
465        .with_max_nodes(max_nodes);
466
467    // Release the GIL while running the beam search (PyO3 0.28: detach == old allow_threads).
468    let results = py.detach(|| rust_discover(features_arr, targets_arr, &config));
469
470    Ok(results
471        .into_iter()
472        .map(|f| PyDiscoveredFormula {
473            op: PyLoweredOp { inner: f.op },
474            mse: f.fitness.mse,
475            r_squared: f.fitness.r_squared,
476            combined: f.fitness.combined,
477            node_count: f.node_count,
478            n_vars: f.n_vars,
479        })
480        .collect())
481}
482
483/// Python view of a discovered formula returned by `discover`.
484#[pyclass(
485    name = "DiscoveredFormula",
486    module = "scirs2.symbolic",
487    skip_from_py_object
488)]
489#[derive(Clone)]
490pub struct PyDiscoveredFormula {
491    /// The lowered operator IR.
492    #[pyo3(get)]
493    pub op: PyLoweredOp,
494    /// Mean-squared error on the training data.
495    #[pyo3(get)]
496    pub mse: f64,
497    /// Coefficient of determination `R²`.
498    #[pyo3(get)]
499    pub r_squared: f64,
500    /// Combined fitness (MSE + parsimony penalty); lower is better.
501    #[pyo3(get)]
502    pub combined: f64,
503    /// Total node count of the formula.
504    #[pyo3(get)]
505    pub node_count: usize,
506    /// Number of distinct variables used.
507    #[pyo3(get)]
508    pub n_vars: usize,
509}
510
511#[pymethods]
512impl PyDiscoveredFormula {
513    fn __repr__(&self) -> String {
514        format!(
515            "DiscoveredFormula(mse={:.6}, r_squared={:.6}, n_nodes={}, n_vars={})",
516            self.mse, self.r_squared, self.node_count, self.n_vars
517        )
518    }
519}
520
521// ============================================================================
522// Module registration
523// ============================================================================
524
525/// Register the `symbolic` sub-namespace on the parent `scirs2` module.
526pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
527    let py = m.py();
528    let symbolic = PyModule::new(py, "symbolic")?;
529
530    symbolic.add_class::<PyEmlTree>()?;
531    symbolic.add_class::<PyCanonical>()?;
532    symbolic.add_class::<PyLoweredOp>()?;
533    symbolic.add_class::<PyDiscoveredFormula>()?;
534
535    symbolic.add_function(wrap_pyfunction!(lower, &symbolic)?)?;
536    symbolic.add_function(wrap_pyfunction!(simplify, &symbolic)?)?;
537    symbolic.add_function(wrap_pyfunction!(grad, &symbolic)?)?;
538    symbolic.add_function(wrap_pyfunction!(eval_real, &symbolic)?)?;
539    symbolic.add_function(wrap_pyfunction!(discover, &symbolic)?)?;
540
541    symbolic.add(
542        "__doc__",
543        "Symbolic mathematics — EML substrate, evaluation, gradient, and \
544         beam-search symbolic regression.\n\nClasses:\n  - EmlTree: uniform \
545         binary EML tree (constant 1 + var leaves + binary eml nodes).\n  - \
546         Canonical: namespace of elementary-function constructors.\n  - \
547         LoweredOp: flat operator IR produced by lower(tree).\n  - \
548         DiscoveredFormula: result of discover().\n\nFunctions:\n  - \
549         lower(tree) -> LoweredOp\n  - simplify(op) -> LoweredOp\n  - grad(op, wrt) \
550         -> LoweredOp\n  - eval_real(op, vars) -> float\n  - discover(features, \
551         targets, ...) -> list[DiscoveredFormula]",
552    )?;
553
554    m.add_submodule(&symbolic)?;
555    Ok(())
556}