torsh_python/lib.rs
1//! Python bindings for ToRSh - PyTorch-compatible deep learning in Rust
2//!
3//! This crate provides Python bindings for the ToRSh deep learning framework,
4//! enabling PyTorch-compatible APIs to be used from Python.
5//!
6//! # Modular Structure
7//!
8//! The crate is organized into focused modules:
9//! - `tensor` - Tensor operations and creation functions
10//! - `nn` - Neural network layers and containers
11//! - `optim` - Optimization algorithms
12//! - `device` - Device management and utilities
13//! - `dtype` - Data type definitions and conversions
14//! - `error` - Error handling and conversions
15//! - `utils` - Common utilities and helpers
16
17use pyo3::prelude::*;
18
19// Core modules - modular structure
20pub mod device;
21pub mod dtype;
22pub mod error;
23pub mod nn;
24pub mod optim;
25pub mod tensor;
26pub mod utils;
27
// Legacy modules (currently disabled; not compiled into the extension)
29// pub mod autograd; // Temporarily disabled due to scirs2 API incompatibilities
30// pub mod distributed; // Temporarily disabled for compilation
31// pub mod functional; // Fixed for PyO3 0.25 but disabled until tensor ops are implemented
32
33// Re-export main types
34pub use device::PyDevice;
35pub use dtype::PyDType;
36pub use error::TorshPyError;
37pub use tensor::PyTensor;
38
39/// ToRSh Python module
40#[pymodule]
41fn torsh(m: &Bound<'_, PyModule>) -> PyResult<()> {
42 // Register main classes
43 m.add_class::<PyTensor>()?;
44 m.add_class::<PyDevice>()?;
45 m.add_class::<PyDType>()?;
46
47 // Add submodules with new modular structure
48 nn::register_nn_module(m.py(), m)?;
49 optim::register_optim_module(m.py(), m)?;
50
51 // let autograd_module = PyModule::new(m.py(), "autograd")?;
52 // autograd::register_autograd_module(m.py(), &autograd_module)?;
53 // m.add_submodule(&autograd_module)?;
54
55 // let distributed_module = PyModule::new(m.py(), "distributed")?;
56 // distributed::register_distributed_module(m.py(), &distributed_module)?;
57 // m.add_submodule(&distributed_module)?;
58
59 // let functional_module = PyModule::new(m.py(), "F")?;
60 // functional::register_functional_module(m.py(), &functional_module)?;
61 // m.add_submodule(&functional_module)?;
62
63 // Add tensor creation functions
64 tensor::register_creation_functions(m)?;
65
66 // Add device and dtype constants
67 device::register_device_constants(m)?;
68 dtype::register_dtype_constants(m)?;
69
70 // Register error types
71 error::register_error_types(m)?;
72
73 // Set version
74 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
75
76 Ok(())
77}
78
79/// A Python module implemented in Rust.
80#[pymodule]
81fn torsh_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
82 torsh(m)
83}