Skip to main content

torsh_python/
lib.rs

1//! Python bindings for ToRSh - PyTorch-compatible deep learning in Rust
2//!
3//! This crate provides Python bindings for the ToRSh deep learning framework,
4//! enabling PyTorch-compatible APIs to be used from Python.
5//!
6//! # Modular Structure
7//!
8//! The crate is organized into focused modules:
9//! - `tensor` - Tensor operations and creation functions
10//! - `nn` - Neural network layers and containers
11//! - `optim` - Optimization algorithms
12//! - `device` - Device management and utilities
13//! - `dtype` - Data type definitions and conversions
14//! - `error` - Error handling and conversions
15//! - `utils` - Common utilities and helpers
16
17use pyo3::prelude::*;
18
19// Core modules - modular structure
20pub mod device;
21pub mod dtype;
22pub mod error;
23pub mod nn;
24pub mod optim;
25pub mod tensor;
26pub mod utils;
27
// Legacy modules (currently disabled; their declarations stay here, commented out, until each is re-enabled)
29// pub mod autograd;  // Temporarily disabled due to scirs2 API incompatibilities
30// pub mod distributed;  // Temporarily disabled for compilation
31// pub mod functional;  // Fixed for PyO3 0.25 but disabled until tensor ops are implemented
32
33// Re-export main types
34pub use device::PyDevice;
35pub use dtype::PyDType;
36pub use error::TorshPyError;
37pub use tensor::PyTensor;
38
39/// ToRSh Python module
40#[pymodule]
41fn torsh(m: &Bound<'_, PyModule>) -> PyResult<()> {
42    // Register main classes
43    m.add_class::<PyTensor>()?;
44    m.add_class::<PyDevice>()?;
45    m.add_class::<PyDType>()?;
46
47    // Add submodules with new modular structure
48    nn::register_nn_module(m.py(), m)?;
49    optim::register_optim_module(m.py(), m)?;
50
51    // let autograd_module = PyModule::new(m.py(), "autograd")?;
52    // autograd::register_autograd_module(m.py(), &autograd_module)?;
53    // m.add_submodule(&autograd_module)?;
54
55    // let distributed_module = PyModule::new(m.py(), "distributed")?;
56    // distributed::register_distributed_module(m.py(), &distributed_module)?;
57    // m.add_submodule(&distributed_module)?;
58
59    // let functional_module = PyModule::new(m.py(), "F")?;
60    // functional::register_functional_module(m.py(), &functional_module)?;
61    // m.add_submodule(&functional_module)?;
62
63    // Add tensor creation functions
64    tensor::register_creation_functions(m)?;
65
66    // Add device and dtype constants
67    device::register_device_constants(m)?;
68    dtype::register_dtype_constants(m)?;
69
70    // Register error types
71    error::register_error_types(m)?;
72
73    // Set version
74    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
75
76    Ok(())
77}
78
79/// A Python module implemented in Rust.
80#[pymodule]
81fn torsh_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
82    torsh(m)
83}