Skip to main content

scirs2/
autograd.rs

1//! Python bindings for scirs2-autograd
2//!
3//! This module provides Python bindings for variable management and model persistence
4//! from scirs2-autograd. Due to Rust lifetime complexities, full computational graph
5//! APIs are not exposed. For comprehensive automatic differentiation, use PyTorch
6//! or TensorFlow which integrate seamlessly with scirs2 via NumPy arrays.
7
8use pyo3::prelude::*;
9use scirs2_autograd::variable::NamespaceTrait;
10use scirs2_autograd::VariableEnvironment;
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13/// Variable environment for managing trainable parameters
14///
15/// Provides save/load functionality for model persistence. For training,
16/// use PyTorch/TensorFlow and transfer weights via NumPy arrays.
17///
18/// Example:
19///     env = scirs2.VariableEnvironment()
20///     # Set variables using NumPy arrays
21///     var_id = env.set_variable("weights", np.random.randn(784, 128))
22///     # Save trained model
23///     env.save("model.json")
24///     # Load model
25///     loaded_env = scirs2.VariableEnvironment.load("model.json")
26#[pyclass(name = "VariableEnvironment", unsendable)]
27pub struct PyVariableEnvironment {
28    inner: VariableEnvironment<f64>,
29}
30
31#[pymethods]
32impl PyVariableEnvironment {
33    /// Create a new variable environment
34    #[new]
35    fn new() -> Self {
36        Self {
37            inner: VariableEnvironment::new(),
38        }
39    }
40
41    /// Set a named variable from a NumPy array
42    ///
43    /// Args:
44    ///     name (str): Variable name
45    ///     array (np.ndarray): Variable value (1D or 2D float64 array)
46    #[allow(deprecated)]
47    fn set_variable(&mut self, name: &str, array: &Bound<'_, PyAny>) -> PyResult<()> {
48        // Try 1D array first
49        if let Ok(arr1d) = array.downcast::<PyArray1<f64>>() {
50            let binding = arr1d.readonly();
51            let data = binding.as_array().to_owned();
52            self.inner.name(name).set(data);
53            return Ok(());
54        }
55
56        // Try 2D array
57        if let Ok(arr2d) = array.downcast::<PyArray2<f64>>() {
58            let binding = arr2d.readonly();
59            let data = binding.as_array().to_owned();
60            self.inner.name(name).set(data);
61            return Ok(());
62        }
63
64        Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
65            "Array must be 1D or 2D float64 numpy array",
66        ))
67    }
68
69    /// Get a named variable as a NumPy array
70    ///
71    /// Args:
72    ///     name (str): Variable name
73    ///
74    /// Returns:
75    ///     np.ndarray: Variable value
76    #[allow(deprecated)]
77    fn get_variable(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
78        let namespace = self.inner.default_namespace();
79        let array_ref = namespace.get_array_by_name(name).ok_or_else(|| {
80            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Variable '{}' not found", name))
81        })?;
82
83        let array = array_ref.borrow();
84
85        // Return based on dimensionality
86        match array.ndim() {
87            1 => {
88                let arr1d = array
89                    .view()
90                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
91                    .map_err(|e| {
92                        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
93                            "Dimension error: {}",
94                            e
95                        ))
96                    })?;
97                Ok(arr1d.to_owned().into_pyarray(py).unbind().into())
98            }
99            2 => {
100                let arr2d = array
101                    .view()
102                    .into_dimensionality::<scirs2_core::ndarray::Ix2>()
103                    .map_err(|e| {
104                        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
105                            "Dimension error: {}",
106                            e
107                        ))
108                    })?;
109                Ok(arr2d.to_owned().into_pyarray(py).unbind().into())
110            }
111            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
112                "Only 1D and 2D arrays are currently supported",
113            )),
114        }
115    }
116
117    /// List all variable names in the default namespace
118    ///
119    /// Returns:
120    ///     list: List of variable names
121    fn list_variables(&self) -> Vec<String> {
122        self.inner
123            .default_namespace()
124            .current_var_names()
125            .into_iter()
126            .map(|s: &str| s.to_string())
127            .collect()
128    }
129
130    /// Save the variable environment to a file
131    ///
132    /// Args:
133    ///     path (str): Path to save file (.json format)
134    fn save(&self, path: &str) -> PyResult<()> {
135        self.inner
136            .save(path)
137            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Save error: {}", e)))
138    }
139
140    /// Load a variable environment from a file
141    ///
142    /// Args:
143    ///     path (str): Path to load file (.json format)
144    ///
145    /// Returns:
146    ///     VariableEnvironment: Loaded environment
147    #[staticmethod]
148    fn load(path: &str) -> PyResult<Self> {
149        let inner = VariableEnvironment::<f64>::load(path).map_err(|e| {
150            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Load error: {}", e))
151        })?;
152
153        Ok(Self { inner })
154    }
155
156    /// Get the number of variables in the environment
157    fn __len__(&self) -> usize {
158        self.inner.default_namespace().current_var_names().len()
159    }
160
161    /// String representation
162    fn __repr__(&self) -> String {
163        format!(
164            "VariableEnvironment({} variables)",
165            self.inner.default_namespace().current_var_names().len()
166        )
167    }
168}
169
170// ============================================================================
171// Module Registration
172// ============================================================================
173
174pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
175    // Register VariableEnvironment class
176    m.add_class::<PyVariableEnvironment>()?;
177
178    // Add module documentation
179    m.add("__doc__", "Automatic differentiation and variable management\n\n\
180        This module provides model parameter storage and persistence from scirs2-autograd.\n\
181        Due to Rust lifetime complexities in the computational graph API, full autodiff\n\
182        functionality is not exposed. For neural network training, we recommend:\n\n\
183        - PyTorch: Industry-standard deep learning framework\n\
184        - TensorFlow: Comprehensive ML platform\n\n\
185        scirs2 arrays are NumPy-compatible, enabling seamless integration with these frameworks.\n\n\
186        Use VariableEnvironment for:\n\
187        - Storing and managing model parameters\n\
188        - Saving/loading trained model weights\n\
189        - Transferring weights between scirs2 and PyTorch/TensorFlow")?;
190
191    Ok(())
192}