
torsh_python/optim/mod.rs

//! Optimization algorithms module - PyTorch-compatible optimizers
//!
//! This module provides a modular structure for optimization algorithms:
//! - `base` - Base PyOptimizer class and common functionality
//! - `sgd` - Stochastic Gradient Descent optimizer
//! - `adam` - Adam and AdamW optimizers
//! - `adagrad` - Adagrad optimizer
//! - `rmsprop` - RMSprop optimizer

pub mod adagrad;
pub mod adam;
pub mod base;
pub mod rmsprop;
pub mod sgd;

// Re-export the main types
pub use adagrad::PyAdaGrad;
pub use adam::{PyAdam, PyAdamW};
pub use base::PyOptimizer;
pub use rmsprop::PyRMSprop;
pub use sgd::PySGD;

use pyo3::prelude::*;
use pyo3::types::{PyModule, PyModuleMethods};

/// Register the optim module with Python
pub fn register_optim_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Register base optimizer
    m.add_class::<PyOptimizer>()?;

    // Register specific optimizers
    m.add_class::<PySGD>()?;
    m.add_class::<PyAdam>()?;
    m.add_class::<PyAdamW>()?;
    m.add_class::<PyAdaGrad>()?;
    m.add_class::<PyRMSprop>()?;

    Ok(())
}
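
For context, a minimal sketch of how a registration function like this is typically wired into a crate's top-level #[pymodule] entry point so the optimizer classes become importable from Python. The parent module name `torsh`, the `crate::optim` path, and the PyO3 ~0.21/0.22 Bound API (`PyModule::new_bound`) are assumptions for illustration; they are not taken from this file.

use pyo3::prelude::*;

#[pymodule]
fn torsh(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Create a nested `optim` submodule and register the optimizer classes on it.
    // (Hypothetical parent module; `new_bound` is the PyO3 0.21/0.22 spelling.)
    let optim = PyModule::new_bound(py, "optim")?;
    crate::optim::register_optim_module(py, &optim)?;
    m.add_submodule(&optim)?;
    Ok(())
}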