Skip to main content

torsh_python/nn/
mod.rs

//! Neural network module: PyTorch-compatible neural network layers and containers.
//!
//! The components are organized into the following submodules:
//! - `module` - Base `PyModule` class and core functionality (re-exported as `PyNNModule`)
//! - `linear` - Linear (dense) layers
//! - `container` - `Sequential`, `ModuleList`, and other containers
//! - `activation` - Activation functions
//! - `loss` - Loss functions
//! - `conv` - Convolutional layers (`Conv1d`, `Conv2d`)
//! - `normalization` - Normalization layers (`BatchNorm1d`, `BatchNorm2d`, `LayerNorm`)
//! - `dropout` - Dropout and regularization layers
//! - `pooling` - Pooling layers (`MaxPool2d`, `AvgPool2d`, and adaptive pooling)
13
14pub mod activation;
15pub mod container;
16pub mod conv;
17pub mod dropout;
18pub mod linear;
19pub mod loss;
20pub mod module;
21pub mod normalization;
22pub mod pooling;
23
24// Re-export the main types
25pub use container::{PyModuleList, PySequential};
26pub use conv::{PyConv1d, PyConv2d};
27pub use dropout::{PyAlphaDropout, PyDropout, PyDropout2d, PyDropout3d};
28pub use linear::PyLinear;
29pub use module::PyModule as PyNNModule;
30pub use normalization::{PyBatchNorm1d, PyBatchNorm2d, PyLayerNorm};
31pub use pooling::{PyAdaptiveAvgPool2d, PyAdaptiveMaxPool2d, PyAvgPool2d, PyMaxPool2d};
32
33use pyo3::prelude::*;
34use pyo3::types::{PyModule, PyModuleMethods};
35
36/// Register the nn module with Python
37pub fn register_nn_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
38    // Register base module
39    m.add_class::<PyNNModule>()?;
40
41    // Register linear layers
42    m.add_class::<PyLinear>()?;
43
44    // Register convolutional layers
45    m.add_class::<PyConv1d>()?;
46    m.add_class::<PyConv2d>()?;
47
48    // Register normalization layers
49    m.add_class::<PyBatchNorm1d>()?;
50    m.add_class::<PyBatchNorm2d>()?;
51    m.add_class::<PyLayerNorm>()?;
52
53    // Register dropout layers
54    m.add_class::<PyDropout>()?;
55    m.add_class::<PyDropout2d>()?;
56    m.add_class::<PyDropout3d>()?;
57    m.add_class::<PyAlphaDropout>()?;
58
59    // Register pooling layers
60    m.add_class::<PyMaxPool2d>()?;
61    m.add_class::<PyAvgPool2d>()?;
62    m.add_class::<PyAdaptiveAvgPool2d>()?;
63    m.add_class::<PyAdaptiveMaxPool2d>()?;
64
65    // Register containers
66    m.add_class::<PySequential>()?;
67    m.add_class::<PyModuleList>()?;
68
69    Ok(())
70}