torsh_ffi/python/
mod.rs

1//! Python bindings for ToRSh via PyO3
2
3use pyo3::prelude::*;
4
5mod dataloader;
6mod functional;
7mod module;
8mod optimizer;
9pub mod tensor;
10mod utils;
11
12pub use dataloader::{PyDataLoader, PyDataLoaderBuilder, PyRandomDataLoader};
13pub use functional::*;
14pub use module::{PyLinear, PyModule};
15pub use optimizer::{PyAdam, PyOptimizer, PySGD};
16pub use tensor::PyTensor;
17pub use utils::*;
18
19// Re-export integration modules
20// TEMPORARILY DISABLED DUE TO PyO3 API COMPATIBILITY ISSUES
21// pub use crate::jupyter_widgets::{
22//     DataExplorationWidget, JupyterWidgets, TensorVisualizationWidget, TrainingMonitorWidget,
23// };
24pub use crate::pandas_support::{DataAnalysisResult, PandasSupport, TorshDataFrame, TorshSeries};
25// TEMPORARILY DISABLED DUE TO PyO3 API COMPATIBILITY ISSUES
26// pub use crate::plotting_utilities::{PlotResult, PlottingUtilities, StatPlotConfig};
27pub use crate::scipy_integration::{
28    LinalgResult, OptimizationResult, SciPyIntegration, SignalResult,
29};
30
31/// Initialize the Python module
32#[pymodule]
33fn torsh(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
34    // Add tensor class
35    m.add_class::<PyTensor>()?;
36
37    // Add neural network modules
38    m.add_class::<PyLinear>()?;
39
40    // Add optimizers
41    m.add_class::<PySGD>()?;
42    m.add_class::<PyAdam>()?;
43
44    // Add data loaders
45    m.add_class::<PyDataLoader>()?;
46    m.add_class::<PyRandomDataLoader>()?;
47    m.add_class::<PyDataLoaderBuilder>()?;
48
49    // Add functional operations
50    // Add functional operations directly to main module
51    m.add_function(wrap_pyfunction!(functional::relu, m)?)?;
52    m.add_function(wrap_pyfunction!(functional::sigmoid, m)?)?;
53    m.add_function(wrap_pyfunction!(functional::tanh, m)?)?;
54    m.add_function(wrap_pyfunction!(functional::softmax, m)?)?;
55    m.add_function(wrap_pyfunction!(functional::cross_entropy, m)?)?;
56    m.add_function(wrap_pyfunction!(functional::mse_loss, m)?)?;
57    m.add_function(wrap_pyfunction!(functional::binary_cross_entropy, m)?)?;
58    m.add_function(wrap_pyfunction!(functional::gelu, m)?)?;
59    m.add_function(wrap_pyfunction!(functional::log_softmax, m)?)?;
60
61    // Add utility functions
62    m.add_function(wrap_pyfunction!(utils::tensor, m)?)?;
63    m.add_function(wrap_pyfunction!(utils::zeros, m)?)?;
64    m.add_function(wrap_pyfunction!(utils::ones, m)?)?;
65    m.add_function(wrap_pyfunction!(utils::randn, m)?)?;
66    m.add_function(wrap_pyfunction!(utils::rand, m)?)?;
67    m.add_function(wrap_pyfunction!(utils::eye, m)?)?;
68    m.add_function(wrap_pyfunction!(utils::full, m)?)?;
69    m.add_function(wrap_pyfunction!(utils::linspace, m)?)?;
70    m.add_function(wrap_pyfunction!(utils::arange, m)?)?;
71    m.add_function(wrap_pyfunction!(utils::stack, m)?)?;
72    m.add_function(wrap_pyfunction!(utils::cat, m)?)?;
73    m.add_function(wrap_pyfunction!(utils::from_numpy, m)?)?;
74    m.add_function(wrap_pyfunction!(utils::to_numpy, m)?)?;
75    m.add_function(wrap_pyfunction!(utils::manual_seed, m)?)?;
76
77    // Add dataloader functions
78    m.add_function(wrap_pyfunction!(dataloader::create_dataloader, m)?)?;
79    m.add_function(wrap_pyfunction!(dataloader::create_dataset_from_array, m)?)?;
80
81    // Register custom exception types
82    #[cfg(feature = "python")]
83    crate::error::python_exceptions::register_exceptions(m)?;
84    m.add_function(wrap_pyfunction!(dataloader::get_dataloader_info, m)?)?;
85    m.add_function(wrap_pyfunction!(dataloader::benchmark_dataloader, m)?)?;
86
87    // Add integration utilities classes
88    m.add_class::<SciPyIntegration>()?;
89    m.add_class::<OptimizationResult>()?;
90    m.add_class::<LinalgResult>()?;
91    m.add_class::<SignalResult>()?;
92
93    m.add_class::<PandasSupport>()?;
94    m.add_class::<TorshDataFrame>()?;
95    m.add_class::<TorshSeries>()?;
96    m.add_class::<DataAnalysisResult>()?;
97
98    // TEMPORARILY DISABLED DUE TO PyO3 API COMPATIBILITY ISSUES
99    // m.add_class::<PlottingUtilities>()?;
100    // m.add_class::<PlotResult>()?;
101    // TEMPORARILY DISABLED DUE TO PyO3 API COMPATIBILITY ISSUES
102    // m.add_class::<StatPlotConfig>()?;
103
104    // m.add_class::<JupyterWidgets>()?;
105    // m.add_class::<TensorVisualizationWidget>()?;
106    // m.add_class::<TrainingMonitorWidget>()?;
107    // m.add_class::<DataExplorationWidget>()?;
108
109    // Create submodules for integration utilities
110    let scipy_utils = crate::scipy_integration::create_scipy_utilities(m.py())?;
111    m.add("scipy", scipy_utils)?;
112
113    let pandas_utils = crate::pandas_support::create_pandas_utilities(m.py())?;
114    m.add("pandas", pandas_utils)?;
115
116    // TEMPORARILY DISABLED DUE TO PyO3 API COMPATIBILITY ISSUES
117    // let plotting_utils = crate::plotting_utilities::create_plotting_utilities(m.py())?;
118    // m.add("plotting", plotting_utils)?;
119
120    // let jupyter_utils = crate::jupyter_widgets::create_jupyter_utilities(m.py())?;
121    // m.add("jupyter", jupyter_utils)?;
122
123    // Add constants
124    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
125
126    // Add device information
127    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
128    m.add_function(wrap_pyfunction!(cuda_device_count, m)?)?;
129
130    Ok(())
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::Python;

    /// `torsh` must initialise a fresh module without error and register its
    /// unconditional top-level attributes.
    #[test]
    fn test_module_creation() {
        Python::initialize();
        Python::attach(|py| {
            let module = pyo3::types::PyModule::new(py, "test_torsh")
                .expect("failed to create an empty Python module");

            // Report the actual Python exception rather than a bare `is_ok()` failure.
            if let Err(err) = torsh(&module) {
                panic!("torsh module initialisation failed: {err}");
            }

            // Spot-check attributes that `torsh` adds on every successful init.
            for name in ["__version__", "scipy", "pandas"] {
                assert!(
                    module.hasattr(name).expect("hasattr raised a Python error"),
                    "initialised module is missing attribute `{name}`"
                );
            }
        });
    }
}