torsh_python/lib.rs
//! Python bindings for ToRSh - PyTorch-compatible deep learning in Rust
//!
//! This crate provides Python bindings for the ToRSh deep learning framework,
//! enabling PyTorch-compatible APIs to be used from Python.
//!
//! # Modular Structure
//!
//! The crate is organized into focused modules:
//! - `tensor` - Tensor operations and creation functions (temporarily disabled)
//! - `nn` - Neural network layers and containers (temporarily disabled)
//! - `optim` - Optimization algorithms (temporarily disabled)
//! - `device` - Device management and utilities
//! - `dtype` - Data type definitions and conversions
//! - `error` - Error handling and conversions
//! - `utils` - Common utilities and helpers
//!
//! The `tensor`, `nn`, and `optim` modules are currently commented out in this
//! file because of scirs2-autograd conflicts, so only `device`, `dtype`,
//! `error`, and `utils` are compiled at the moment.
16
17use pyo3::prelude::*;
18
19// Core modules - modular structure
20pub mod device;
21pub mod dtype;
22pub mod error;
23// pub mod nn; // Temporarily disabled due to scirs2-autograd conflicts
24// pub mod optim; // Temporarily disabled due to scirs2-autograd conflicts
25// pub mod tensor; // Temporarily disabled due to scirs2-autograd conflicts
26pub mod utils;
27
28// Legacy modules (temporarily kept for compatibility)
29// pub mod autograd; // Temporarily disabled due to scirs2 API incompatibilities
30// pub mod distributed; // Temporarily disabled for compilation
31// pub mod functional; // Fixed for PyO3 0.25 but disabled until tensor ops are implemented
32
33// Re-export main types
34pub use device::PyDevice;
35pub use dtype::PyDType;
36pub use error::TorshPyError;
37// pub use tensor::PyTensor; // Temporarily disabled due to scirs2-autograd conflicts
38
39/// ToRSh Python module
40#[pymodule]
41fn torsh(m: &Bound<'_, PyModule>) -> PyResult<()> {
42 // Register main classes
43 // m.add_class::<PyTensor>()?; // Temporarily disabled due to scirs2-autograd conflicts
44 m.add_class::<PyDevice>()?;
45 m.add_class::<PyDType>()?;
46
47 // Add submodules with new modular structure
48 // let nn_module = PyModule::new(m.py(), "nn")?;
49 // nn::register_nn_module(m.py(), &nn_module)?;
50 // m.add_submodule(&nn_module)?;
51
52 // let optim_module = PyModule::new(m.py(), "optim")?;
53 // optim::register_optim_module(m.py(), &optim_module)?;
54 // m.add_submodule(&optim_module)?;
55
56 // let autograd_module = PyModule::new(m.py(), "autograd")?;
57 // autograd::register_autograd_module(m.py(), &autograd_module)?;
58 // m.add_submodule(&autograd_module)?;
59
60 // let distributed_module = PyModule::new(m.py(), "distributed")?;
61 // distributed::register_distributed_module(m.py(), &distributed_module)?;
62 // m.add_submodule(&distributed_module)?;
63
64 // let functional_module = PyModule::new(m.py(), "F")?;
65 // functional::register_functional_module(m.py(), &functional_module)?;
66 // m.add_submodule(&functional_module)?;
67
68 // Add tensor creation functions
69 // tensor::register_creation_functions(m)?; // Disabled: tensor module commented out
70
71 // Add device and dtype constants
72 device::register_device_constants(m)?;
73 dtype::register_dtype_constants(m)?;
74
75 // Register error types
76 error::register_error_types(m)?;
77
78 // Set version
79 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
80
81 Ok(())
82}
83
84/// A Python module implemented in Rust.
85#[pymodule]
86fn torsh_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
87 torsh(m)
88}