1use pyo3::prelude::*;
9use scirs2_autograd::variable::NamespaceTrait;
10use scirs2_autograd::VariableEnvironment;
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13#[pyclass(name = "VariableEnvironment", unsendable)]
27pub struct PyVariableEnvironment {
28 inner: VariableEnvironment<f64>,
29}
30
31#[pymethods]
32impl PyVariableEnvironment {
33 #[new]
35 fn new() -> Self {
36 Self {
37 inner: VariableEnvironment::new(),
38 }
39 }
40
41 #[allow(deprecated)]
47 fn set_variable(&mut self, name: &str, array: &Bound<'_, PyAny>) -> PyResult<()> {
48 if let Ok(arr1d) = array.downcast::<PyArray1<f64>>() {
50 let binding = arr1d.readonly();
51 let data = binding.as_array().to_owned();
52 self.inner.name(name).set(data);
53 return Ok(());
54 }
55
56 if let Ok(arr2d) = array.downcast::<PyArray2<f64>>() {
58 let binding = arr2d.readonly();
59 let data = binding.as_array().to_owned();
60 self.inner.name(name).set(data);
61 return Ok(());
62 }
63
64 Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
65 "Array must be 1D or 2D float64 numpy array",
66 ))
67 }
68
69 #[allow(deprecated)]
77 fn get_variable(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
78 let namespace = self.inner.default_namespace();
79 let array_ref = namespace.get_array_by_name(name).ok_or_else(|| {
80 PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Variable '{}' not found", name))
81 })?;
82
83 let array = array_ref.borrow();
84
85 match array.ndim() {
87 1 => {
88 let arr1d = array
89 .view()
90 .into_dimensionality::<scirs2_core::ndarray::Ix1>()
91 .map_err(|e| {
92 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
93 "Dimension error: {}",
94 e
95 ))
96 })?;
97 Ok(arr1d.to_owned().into_pyarray(py).unbind().into())
98 }
99 2 => {
100 let arr2d = array
101 .view()
102 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
103 .map_err(|e| {
104 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
105 "Dimension error: {}",
106 e
107 ))
108 })?;
109 Ok(arr2d.to_owned().into_pyarray(py).unbind().into())
110 }
111 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
112 "Only 1D and 2D arrays are currently supported",
113 )),
114 }
115 }
116
117 fn list_variables(&self) -> Vec<String> {
122 self.inner
123 .default_namespace()
124 .current_var_names()
125 .into_iter()
126 .map(|s: &str| s.to_string())
127 .collect()
128 }
129
130 fn save(&self, path: &str) -> PyResult<()> {
135 self.inner
136 .save(path)
137 .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Save error: {}", e)))
138 }
139
140 #[staticmethod]
148 fn load(path: &str) -> PyResult<Self> {
149 let inner = VariableEnvironment::<f64>::load(path).map_err(|e| {
150 PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Load error: {}", e))
151 })?;
152
153 Ok(Self { inner })
154 }
155
156 fn __len__(&self) -> usize {
158 self.inner.default_namespace().current_var_names().len()
159 }
160
161 fn __repr__(&self) -> String {
163 format!(
164 "VariableEnvironment({} variables)",
165 self.inner.default_namespace().current_var_names().len()
166 )
167 }
168}
169
170pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
175 m.add_class::<PyVariableEnvironment>()?;
177
178 m.add("__doc__", "Automatic differentiation and variable management\n\n\
180 This module provides model parameter storage and persistence from scirs2-autograd.\n\
181 Due to Rust lifetime complexities in the computational graph API, full autodiff\n\
182 functionality is not exposed. For neural network training, we recommend:\n\n\
183 - PyTorch: Industry-standard deep learning framework\n\
184 - TensorFlow: Comprehensive ML platform\n\n\
185 scirs2 arrays are NumPy-compatible, enabling seamless integration with these frameworks.\n\n\
186 Use VariableEnvironment for:\n\
187 - Storing and managing model parameters\n\
188 - Saving/loading trained model weights\n\
189 - Transferring weights between scirs2 and PyTorch/TensorFlow")?;
190
191 Ok(())
192}