1use pyo3::prelude::*;
4
5mod dataloader;
6mod functional;
7mod module;
8mod optimizer;
9pub mod tensor;
10mod utils;
11
12pub use dataloader::{PyDataLoader, PyDataLoaderBuilder, PyRandomDataLoader};
13pub use functional::*;
14pub use module::{PyLinear, PyModule};
15pub use optimizer::{PyAdam, PyOptimizer, PySGD};
16pub use tensor::PyTensor;
17pub use utils::*;
18
19pub use crate::pandas_support::{DataAnalysisResult, PandasSupport, TorshDataFrame, TorshSeries};
25pub use crate::scipy_integration::{
28 LinalgResult, OptimizationResult, SciPyIntegration, SignalResult,
29};
30
31#[pymodule]
33fn torsh(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
34 m.add_class::<PyTensor>()?;
36
37 m.add_class::<PyLinear>()?;
39
40 m.add_class::<PySGD>()?;
42 m.add_class::<PyAdam>()?;
43
44 m.add_class::<PyDataLoader>()?;
46 m.add_class::<PyRandomDataLoader>()?;
47 m.add_class::<PyDataLoaderBuilder>()?;
48
49 m.add_function(wrap_pyfunction!(functional::relu, m)?)?;
52 m.add_function(wrap_pyfunction!(functional::sigmoid, m)?)?;
53 m.add_function(wrap_pyfunction!(functional::tanh, m)?)?;
54 m.add_function(wrap_pyfunction!(functional::softmax, m)?)?;
55 m.add_function(wrap_pyfunction!(functional::cross_entropy, m)?)?;
56 m.add_function(wrap_pyfunction!(functional::mse_loss, m)?)?;
57 m.add_function(wrap_pyfunction!(functional::binary_cross_entropy, m)?)?;
58 m.add_function(wrap_pyfunction!(functional::gelu, m)?)?;
59 m.add_function(wrap_pyfunction!(functional::log_softmax, m)?)?;
60
61 m.add_function(wrap_pyfunction!(utils::tensor, m)?)?;
63 m.add_function(wrap_pyfunction!(utils::zeros, m)?)?;
64 m.add_function(wrap_pyfunction!(utils::ones, m)?)?;
65 m.add_function(wrap_pyfunction!(utils::randn, m)?)?;
66 m.add_function(wrap_pyfunction!(utils::rand, m)?)?;
67 m.add_function(wrap_pyfunction!(utils::eye, m)?)?;
68 m.add_function(wrap_pyfunction!(utils::full, m)?)?;
69 m.add_function(wrap_pyfunction!(utils::linspace, m)?)?;
70 m.add_function(wrap_pyfunction!(utils::arange, m)?)?;
71 m.add_function(wrap_pyfunction!(utils::stack, m)?)?;
72 m.add_function(wrap_pyfunction!(utils::cat, m)?)?;
73 m.add_function(wrap_pyfunction!(utils::from_numpy, m)?)?;
74 m.add_function(wrap_pyfunction!(utils::to_numpy, m)?)?;
75 m.add_function(wrap_pyfunction!(utils::manual_seed, m)?)?;
76
77 m.add_function(wrap_pyfunction!(dataloader::create_dataloader, m)?)?;
79 m.add_function(wrap_pyfunction!(dataloader::create_dataset_from_array, m)?)?;
80
81 #[cfg(feature = "python")]
83 crate::error::python_exceptions::register_exceptions(m)?;
84 m.add_function(wrap_pyfunction!(dataloader::get_dataloader_info, m)?)?;
85 m.add_function(wrap_pyfunction!(dataloader::benchmark_dataloader, m)?)?;
86
87 m.add_class::<SciPyIntegration>()?;
89 m.add_class::<OptimizationResult>()?;
90 m.add_class::<LinalgResult>()?;
91 m.add_class::<SignalResult>()?;
92
93 m.add_class::<PandasSupport>()?;
94 m.add_class::<TorshDataFrame>()?;
95 m.add_class::<TorshSeries>()?;
96 m.add_class::<DataAnalysisResult>()?;
97
98 let scipy_utils = crate::scipy_integration::create_scipy_utilities(m.py())?;
111 m.add("scipy", scipy_utils)?;
112
113 let pandas_utils = crate::pandas_support::create_pandas_utilities(m.py())?;
114 m.add("pandas", pandas_utils)?;
115
116 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
125
126 m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
128 m.add_function(wrap_pyfunction!(cuda_device_count, m)?)?;
129
130 Ok(())
131}
132
133#[cfg(test)]
134mod tests {
135 use super::*;
136 use pyo3::Python;
137
138 #[test]
139 fn test_module_creation() {
140 Python::initialize();
141 Python::attach(|py| {
142 let module = pyo3::types::PyModule::new(py, "test_torsh").unwrap();
143 let result = torsh(&module);
144 assert!(result.is_ok());
145 });
146 }
147}