1#![allow(dead_code)]
5use crate::error::FfiError;
6use crate::python::tensor::PyTensor;
7use pyo3::prelude::*;
8
9#[pyclass(name = "Module", subclass)]
11#[derive(Clone)]
12pub struct PyModule {
13 name: String,
15}
16
17#[pymethods]
18impl PyModule {
19 fn forward(&self, _input: &PyTensor) -> PyResult<PyTensor> {
21 Err(FfiError::UnsupportedOperation {
22 operation: "forward not implemented for base Module".to_string(),
23 }
24 .into())
25 }
26
27 fn train(&mut self, mode: Option<bool>) {
29 let _training = mode.unwrap_or(true);
30 }
32
33 fn eval(&mut self) {
35 self.train(Some(false));
36 }
37
38 fn parameters(&self) -> Vec<PyTensor> {
40 Vec::new()
41 }
42
43 fn __repr__(&self) -> String {
44 format!("{}()", self.name)
45 }
46}
47
48#[pyclass(name = "Linear")]
50pub struct PyLinear {
51 in_features: usize,
52 out_features: usize,
53 bias: bool,
54 weight: PyTensor,
55 bias_tensor: Option<PyTensor>,
56}
57
58#[pymethods]
59impl PyLinear {
60 #[new]
61 fn new(in_features: usize, out_features: usize, bias: Option<bool>) -> PyResult<Self> {
62 let use_bias = bias.unwrap_or(true);
63
64 let weight_data: Vec<f32> = (0..out_features * in_features)
66 .map(|i| (i as f32) * 0.01 - 0.005) .collect();
68
69 let weight = Python::attach(|py| {
70 let data = pyo3::types::PyList::new(py, &weight_data)?;
71 PyTensor::new(
72 data.as_ref(),
73 Some(vec![out_features, in_features]),
74 Some("f32"),
75 true,
76 )
77 })?;
78
79 let bias_tensor = if use_bias {
80 let bias_data: Vec<f32> = (0..out_features).map(|_| 0.0).collect();
81 let bias_tensor = Python::attach(|py| {
82 let data = pyo3::types::PyList::new(py, &bias_data)?;
83 PyTensor::new(data.as_ref(), Some(vec![out_features]), Some("f32"), true)
84 })?;
85 Some(bias_tensor)
86 } else {
87 None
88 };
89
90 Ok(PyLinear {
91 in_features,
92 out_features,
93 bias: use_bias,
94 weight,
95 bias_tensor,
96 })
97 }
98
99 fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
101 if input.shape().len() < 2 {
103 return Err(FfiError::ShapeMismatch {
104 expected: vec![0, self.in_features], actual: input.shape(),
106 }
107 .into());
108 }
109
110 let input_features = input.shape()[input.shape().len() - 1];
111 if input_features != self.in_features {
112 return Err(FfiError::ShapeMismatch {
113 expected: vec![self.in_features],
114 actual: vec![input_features],
115 }
116 .into());
117 }
118
119 if input.shape().len() != 2 {
122 return Err(FfiError::UnsupportedOperation {
123 operation: "Only 2D input currently supported".to_string(),
124 }
125 .into());
126 }
127
128 let weight_t = self
129 .weight
130 .t_internal()
131 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{:?}", e)))?;
132 let output = input
133 .matmul_internal(&weight_t)
134 .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{:?}", e)))?;
135
136 if let Some(ref _bias) = self.bias_tensor {
138 for batch_idx in 0..input.shape()[0] {
141 for _feature_idx in 0..self.out_features {
142 let _output_idx = batch_idx * self.out_features + _feature_idx;
143 }
145 }
146 }
147
148 Ok(output)
149 }
150
151 #[getter]
152 fn weight(&self) -> PyTensor {
153 self.weight.clone()
154 }
155
156 #[getter]
157 fn bias(&self) -> Option<PyTensor> {
158 self.bias_tensor.clone()
159 }
160
161 #[getter]
162 fn in_features(&self) -> usize {
163 self.in_features
164 }
165
166 #[getter]
167 fn out_features(&self) -> usize {
168 self.out_features
169 }
170
171 fn __repr__(&self) -> String {
172 format!(
173 "Linear(in_features={}, out_features={}, bias={})",
174 self.in_features, self.out_features, self.bias
175 )
176 }
177}
178
179#[pyclass(name = "Conv2d")]
181pub struct PyConv2d {
182 in_channels: usize,
183 out_channels: usize,
184 kernel_size: (usize, usize),
185 stride: (usize, usize),
186 padding: (usize, usize),
187}
188
189#[pymethods]
190impl PyConv2d {
191 #[new]
192 fn new(
193 in_channels: usize,
194 out_channels: usize,
195 kernel_size: (usize, usize),
196 stride: Option<(usize, usize)>,
197 padding: Option<(usize, usize)>,
198 ) -> Self {
199 PyConv2d {
200 in_channels,
201 out_channels,
202 kernel_size,
203 stride: stride.unwrap_or((1, 1)),
204 padding: padding.unwrap_or((0, 0)),
205 }
206 }
207
208 fn forward(&self, _input: &PyTensor) -> PyResult<PyTensor> {
209 Err(FfiError::UnsupportedOperation {
210 operation: "Conv2d forward not yet implemented".to_string(),
211 }
212 .into())
213 }
214
215 fn __repr__(&self) -> String {
216 format!(
217 "Conv2d({}, {}, kernel_size={:?}, stride={:?}, padding={:?})",
218 self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding
219 )
220 }
221}
222
223#[pyclass(name = "ReLU")]
225pub struct PyReLU {
226 inplace: bool,
227}
228
229#[pymethods]
230impl PyReLU {
231 #[new]
232 fn new(inplace: Option<bool>) -> Self {
233 PyReLU {
234 inplace: inplace.unwrap_or(false),
235 }
236 }
237
238 fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
239 let result_data: Vec<f32> = input.data.iter().map(|&x| x.max(0.0)).collect();
241
242 Python::attach(|py| {
243 let data = pyo3::types::PyList::new(py, &result_data)?;
244 PyTensor::new(
245 data.as_ref(),
246 Some(input.shape()),
247 Some("f32"),
248 input.requires_grad,
249 )
250 })
251 }
252
253 fn __repr__(&self) -> String {
254 if self.inplace {
255 "ReLU(inplace=True)".to_string()
256 } else {
257 "ReLU()".to_string()
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::types::PyList;
    use pyo3::Python;

    #[test]
    fn test_linear_creation() {
        Python::initialize();
        let linear = PyLinear::new(10, 5, None).unwrap();
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        // Weight is stored as [out_features, in_features].
        assert_eq!(linear.weight().shape(), vec![5, 10]);
        assert!(linear.bias().is_some(), "bias is enabled by default");
    }

    #[test]
    fn test_linear_without_bias() {
        Python::initialize();
        let linear = PyLinear::new(3, 2, Some(false)).unwrap();
        assert!(linear.bias().is_none());
    }

    #[test]
    fn test_relu_forward() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![-1.0, 0.0, 1.0, 2.0]).unwrap();
            let input = PyTensor::new(data.as_ref(), None, None, false).unwrap();

            let output = PyReLU::new(None).forward(&input).unwrap();

            let values: Vec<f32> = output.data.iter().copied().collect();
            assert_eq!(values, vec![0.0, 0.0, 1.0, 2.0]);
        });
    }

    #[test]
    fn test_reprs() {
        assert_eq!(PyReLU::new(None).__repr__(), "ReLU()");
        assert_eq!(PyReLU::new(Some(true)).__repr__(), "ReLU(inplace=True)");
        assert_eq!(
            PyConv2d::new(3, 16, (3, 3), None, None).__repr__(),
            "Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))"
        );
    }
}