torsh_python/nn/linear.rs

//! Linear (fully connected) neural network layer

use super::module::PyModule;
use crate::{device::PyDevice, error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use std::collections::HashMap;
use torsh_tensor::Tensor;

/// Linear (fully connected) layer: computes `y = x @ weight.T + bias`, with
/// `weight` stored as `[out_features, in_features]`
#[pyclass(name = "Linear", extends = PyModule)]
pub struct PyLinear {
    weight: Tensor<f32>,
    bias: Option<Tensor<f32>>,
    in_features: usize,
    out_features: usize,
    has_bias: bool,
    training: bool,
}

#[pymethods]
impl PyLinear {
    #[new]
    fn new(
        in_features: usize,
        out_features: usize,
        bias: Option<bool>,
    ) -> PyResult<(Self, PyModule)> {
        let has_bias = bias.unwrap_or(true);

        // Initialize weight from a standard normal distribution
        // (no Xavier/Glorot scaling is applied here)
        let weight_shape = vec![out_features, in_features];
        let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);

        // Initialize bias if needed
        let bias = if has_bias {
            let bias_shape = vec![out_features];
            Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
        } else {
            None
        };

        Ok((
            Self {
                weight,
                bias,
                in_features,
                out_features,
                has_bias,
                training: true,
            },
            PyModule::new(),
        ))
    }
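
    // Hypothetical Python-side construction (the exact import path depends on how the
    // extension module is packaged; passing `bias` as a keyword argument is assumed):
    //
    //     layer = Linear(128, 64)               # weight: [64, 128], bias: [64]
    //     no_bias = Linear(128, 64, bias=False)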

    /// Forward pass through the linear layer
    fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
        // Compute input @ weight.T: the weight is stored as [out_features, in_features],
        // so it is transposed before the matmul (a `transpose(0, 1)` that returns a
        // Result, like the other tensor ops used here, is assumed)
        let weight_t = py_result!(self.weight.transpose(0, 1))?;
        let result = py_result!(input.tensor.matmul(&weight_t))?;

        // Add bias if present
        let result = if let Some(ref bias) = self.bias {
            py_result!(result.add(bias))?
        } else {
            result
        };

        Ok(PyTensor { tensor: result })
    }
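
    // Hypothetical forward call from Python, following the shapes sketched above:
    //
    //     y = layer.forward(x)                  # x: [batch, 128] -> y: [batch, 64]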

    /// Get all parameters (weight and bias if present)
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        let mut params = Vec::new();

        // Add weight parameter
        params.push(PyTensor {
            tensor: self.weight.clone(),
        });

        // Add bias parameter if present
        if let Some(ref bias) = self.bias {
            params.push(PyTensor {
                tensor: bias.clone(),
            });
        }

        Ok(params)
    }

    /// Get named parameters
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        let mut named_params = HashMap::new();

        // Add weight parameter
        named_params.insert(
            "weight".to_string(),
            PyTensor {
                tensor: self.weight.clone(),
            },
        );

        // Add bias parameter if present
        if let Some(ref bias) = self.bias {
            named_params.insert(
                "bias".to_string(),
                PyTensor {
                    tensor: bias.clone(),
                },
            );
        }

        Ok(named_params)
    }
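
    // Hypothetical state-dict round trip from Python, using the "weight"/"bias" keys
    // shared by `named_parameters` and `load_state_dict`:
    //
    //     sd = layer.named_parameters()         # {"weight": Tensor, "bias": Tensor}
    //     other.load_state_dict(sd)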

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
        // Linear layers don't have different behavior in train/eval mode
        // but we track the state for consistency
    }

    /// Set evaluation mode
    fn eval(&mut self) {
        self.training = false;
    }

    /// Move layer to specified device
    fn to(&mut self, device: PyDevice) -> PyResult<()> {
        // Move weight to device
        self.weight = py_result!(self.weight.clone().to(device.device))?;

        // Move bias to device if present
        if let Some(ref bias) = self.bias {
            self.bias = Some(py_result!(bias.clone().to(device.device))?);
        }

        Ok(())
    }

    /// Zero gradients of all parameters
    fn zero_grad(&mut self) {
        // Zero gradients for weight and bias
        let _ = self.weight.zero_grad();
        if let Some(ref mut bias) = self.bias {
            let _ = bias.zero_grad();
        }
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "Linear(in_features={}, out_features={}, bias={})",
            self.in_features,
            self.out_features,
            if self.has_bias { "True" } else { "False" }
        )
    }

    /// Get input features
    #[getter]
    fn in_features(&self) -> usize {
        self.in_features
    }

    /// Get output features
    #[getter]
    fn out_features(&self) -> usize {
        self.out_features
    }

    /// Check if bias is enabled
    #[getter]
    fn bias(&self) -> bool {
        self.has_bias
    }

    /// Check if module is in training mode
    #[getter]
    fn training(&self) -> bool {
        self.training
    }

    /// Get weight tensor
    #[getter]
    fn weight(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.weight.clone(),
        })
    }

    /// Load state dictionary
    fn load_state_dict(&mut self, state_dict: HashMap<String, PyTensor>) -> PyResult<()> {
        // Load weight
        if let Some(weight_tensor) = state_dict.get("weight") {
            self.weight = weight_tensor.tensor.clone();
        }

        // Load bias if present
        if self.has_bias {
            if let Some(bias_tensor) = state_dict.get("bias") {
                self.bias = Some(bias_tensor.tensor.clone());
            }
        }

        Ok(())
    }
}