torsh_python/
lib.rs

1//! Python bindings for ToRSh - PyTorch-compatible deep learning in Rust
2//!
3//! This crate provides Python bindings for the ToRSh deep learning framework,
4//! enabling PyTorch-compatible APIs to be used from Python.
5//!
6//! # Modular Structure
7//!
8//! The crate is organized into focused modules:
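//! - `tensor` - Tensor operations and creation functions (currently disabled: scirs2-autograd conflicts)
//! - `nn` - Neural network layers and containers (currently disabled: scirs2-autograd conflicts)
//! - `optim` - Optimization algorithms (currently disabled: scirs2-autograd conflicts)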
12//! - `device` - Device management and utilities
13//! - `dtype` - Data type definitions and conversions
14//! - `error` - Error handling and conversions
15//! - `utils` - Common utilities and helpers
16
17use pyo3::prelude::*;
18
19// Core modules - modular structure
20pub mod device;
21pub mod dtype;
22pub mod error;
23// pub mod nn;  // Temporarily disabled due to scirs2-autograd conflicts
24// pub mod optim;  // Temporarily disabled due to scirs2-autograd conflicts
25// pub mod tensor;  // Temporarily disabled due to scirs2-autograd conflicts
26pub mod utils;
27
28// Legacy modules (temporarily kept for compatibility)
29// pub mod autograd;  // Temporarily disabled due to scirs2 API incompatibilities
30// pub mod distributed;  // Temporarily disabled for compilation
31// pub mod functional;  // Fixed for PyO3 0.25 but disabled until tensor ops are implemented
32
33// Re-export main types
34pub use device::PyDevice;
35pub use dtype::PyDType;
36pub use error::TorshPyError;
37// pub use tensor::PyTensor;  // Temporarily disabled due to scirs2-autograd conflicts
38
39/// ToRSh Python module
40#[pymodule]
41fn torsh(m: &Bound<'_, PyModule>) -> PyResult<()> {
42    // Register main classes
43    // m.add_class::<PyTensor>()?;  // Temporarily disabled due to scirs2-autograd conflicts
44    m.add_class::<PyDevice>()?;
45    m.add_class::<PyDType>()?;
46
47    // Add submodules with new modular structure
48    // let nn_module = PyModule::new(m.py(), "nn")?;
49    // nn::register_nn_module(m.py(), &nn_module)?;
50    // m.add_submodule(&nn_module)?;
51
52    // let optim_module = PyModule::new(m.py(), "optim")?;
53    // optim::register_optim_module(m.py(), &optim_module)?;
54    // m.add_submodule(&optim_module)?;
55
56    // let autograd_module = PyModule::new(m.py(), "autograd")?;
57    // autograd::register_autograd_module(m.py(), &autograd_module)?;
58    // m.add_submodule(&autograd_module)?;
59
60    // let distributed_module = PyModule::new(m.py(), "distributed")?;
61    // distributed::register_distributed_module(m.py(), &distributed_module)?;
62    // m.add_submodule(&distributed_module)?;
63
64    // let functional_module = PyModule::new(m.py(), "F")?;
65    // functional::register_functional_module(m.py(), &functional_module)?;
66    // m.add_submodule(&functional_module)?;
67
68    // Add tensor creation functions
69    // tensor::register_creation_functions(m)?; // Disabled: tensor module commented out
70
71    // Add device and dtype constants
72    device::register_device_constants(m)?;
73    dtype::register_dtype_constants(m)?;
74
75    // Register error types
76    error::register_error_types(m)?;
77
78    // Set version
79    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
80
81    Ok(())
82}
83
84/// A Python module implemented in Rust.
85#[pymodule]
86fn torsh_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
87    torsh(m)
88}