torsh_python/nn/loss.rs

//! Loss functions for neural networks

use super::module::PyModule;
use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;

/// Mean Squared Error loss
#[pyclass(name = "MSELoss", extends = PyModule)]
pub struct PyMSELoss {
    reduction: String,
    training: bool,
}

#[pymethods]
impl PyMSELoss {
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let reduction = reduction.unwrap_or_else(|| "mean".to_string());
        (
            Self {
                reduction,
                training: true,
            },
            PyModule::new(),
        )
    }

    /// Forward pass through MSE Loss
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // MSE = mean((input - target)^2)
        // NOTE: `self.reduction` is not consulted yet; the "mean" reduction is
        // always applied, regardless of the value passed to the constructor.
        let diff = py_result!(input.tensor.sub(&target.tensor))?;
        let squared = py_result!(diff.pow(2.0))?;
        let result = py_result!(squared.mean(None, false))?;
        Ok(PyTensor { tensor: result })
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!("MSELoss(reduction='{}')", self.reduction)
    }

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Set evaluation mode
    fn eval(&mut self) {
        self.training = false;
    }

    /// Check if module is in training mode
    fn training(&self) -> bool {
        self.training
    }
}
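
// A reduction-aware variant of the MSE forward above, kept as a standalone
// minimal sketch rather than wired into the class. It only uses methods already
// seen in this module (`sub`, `pow`, `mean`); the helper name and the
// commented-out `sum` call are assumptions, not part of the tensor API shown here.
#[allow(dead_code)]
fn mse_forward_with_reduction(
    loss: &PyMSELoss,
    input: &PyTensor,
    target: &PyTensor,
) -> PyResult<PyTensor> {
    let diff = py_result!(input.tensor.sub(&target.tensor))?;
    let squared = py_result!(diff.pow(2.0))?;
    let result = match loss.reduction.as_str() {
        // "none": return the per-element squared error unreduced
        "none" => squared,
        // "sum" would slot in here once a tensor-level sum reduction exists,
        // e.g. a hypothetical `py_result!(squared.sum(None, false))?`.
        // default ("mean"): average over all elements
        _ => py_result!(squared.mean(None, false))?,
    };
    Ok(PyTensor { tensor: result })
}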

/// Cross Entropy loss
#[pyclass(name = "CrossEntropyLoss", extends = PyModule)]
pub struct PyCrossEntropyLoss {
    reduction: String,
    training: bool,
}

#[pymethods]
impl PyCrossEntropyLoss {
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let reduction = reduction.unwrap_or_else(|| "mean".to_string());
        (
            Self {
                reduction,
                training: true,
            },
            PyModule::new(),
        )
    }

    /// Forward pass through Cross Entropy Loss
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // Placeholder only: a real cross-entropy would apply log-softmax over the
        // class dimension and take the negative log-likelihood of the target class.
        // Until that is implemented, this simply returns `input - target`.
        let result = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!("CrossEntropyLoss(reduction='{}')", self.reduction)
    }

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Set evaluation mode
    fn eval(&mut self) {
        self.training = false;
    }

    /// Check if module is in training mode
    fn training(&self) -> bool {
        self.training
    }
}
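
// Sketch of what a real cross-entropy forward could look like, left as a comment
// because the tensor methods it needs (`log_softmax`, `gather`, `neg`) are not
// used anywhere in this module and are only assumed to exist:
//
//     // 1. log-probabilities over the class dimension
//     let log_probs = py_result!(input.tensor.log_softmax(1))?;
//     // 2. per-sample log-probability of the target class
//     let picked = py_result!(log_probs.gather(1, &target.tensor))?;
//     // 3. negative log-likelihood, reduced according to `self.reduction`
//     let nll = py_result!(picked.neg())?;
//     let result = py_result!(nll.mean(None, false))?; // "mean" reduction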

/// Binary Cross Entropy loss
#[pyclass(name = "BCELoss", extends = PyModule)]
pub struct PyBCELoss {
    reduction: String,
    training: bool,
}

#[pymethods]
impl PyBCELoss {
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let reduction = reduction.unwrap_or_else(|| "mean".to_string());
        (
            Self {
                reduction,
                training: true,
            },
            PyModule::new(),
        )
    }

    /// Forward pass through BCE Loss
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // Placeholder only: a real BCE would compute
        // -(target * log(input) + (1 - target) * log(1 - input)) and then reduce.
        // Until that is implemented, this simply returns `input - target`.
        let result = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!("BCELoss(reduction='{}')", self.reduction)
    }

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Set evaluation mode
    fn eval(&mut self) {
        self.training = false;
    }

    /// Check if module is in training mode
    fn training(&self) -> bool {
        self.training
    }
}
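
// Sketch of a real binary cross-entropy forward, also left as a comment: the
// element-wise `log`, `mul`, `add`, and `neg` calls and the "1 - tensor" steps
// below are assumptions about the tensor API, not methods shown in this module.
//
//     // per-element loss: -(target * log(input) + (1 - target) * log(1 - input))
//     let pos = py_result!(target.tensor.mul(&py_result!(input.tensor.log())?))?;
//     let one_minus_input = /* 1 - input (scalar-minus-tensor op, not shown here) */;
//     let one_minus_target = /* 1 - target */;
//     let neg = py_result!(one_minus_target.mul(&py_result!(one_minus_input.log())?))?;
//     let per_element = py_result!(py_result!(pos.add(&neg))?.neg())?;
//     // then reduce according to `self.reduction` ("mean", "sum", or "none")
//     let result = py_result!(per_element.mean(None, false))?;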