1use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6
7#[pyclass(name = "MSELoss", extends = PyModule)]
9pub struct PyMSELoss {
10 reduction: String,
11 training: bool,
12}
13
14#[pymethods]
15impl PyMSELoss {
16 #[new]
17 fn new(reduction: Option<String>) -> (Self, PyModule) {
18 let reduction = reduction.unwrap_or_else(|| "mean".to_string());
19 (
20 Self {
21 reduction,
22 training: true,
23 },
24 PyModule::new(),
25 )
26 }
27
28 fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
30 let diff = py_result!(input.tensor.sub(&target.tensor))?;
32 let squared = py_result!(diff.pow(2.0))?;
33 let result = py_result!(squared.mean(None, false))?;
34 Ok(PyTensor { tensor: result })
35 }
36
37 fn __repr__(&self) -> String {
39 format!("MSELoss(reduction='{}')", self.reduction)
40 }
41
42 fn train(&mut self, mode: Option<bool>) {
44 self.training = mode.unwrap_or(true);
45 }
46
47 fn eval(&mut self) {
49 self.training = false;
50 }
51
52 fn training(&self) -> bool {
54 self.training
55 }
56}
57
58#[pyclass(name = "CrossEntropyLoss", extends = PyModule)]
60pub struct PyCrossEntropyLoss {
61 reduction: String,
62 training: bool,
63}
64
#[pymethods]
impl PyCrossEntropyLoss {
    /// Creates a `CrossEntropyLoss` module.
    ///
    /// `reduction` defaults to `"mean"` when omitted; it is stored and shown
    /// by `__repr__`, but `forward` does not currently read it. The module
    /// starts in training mode.
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let reduction = reduction.unwrap_or_else(|| "mean".to_string());
        (
            Self {
                reduction,
                training: true,
            },
            PyModule::new(),
        )
    }

    /// Placeholder forward pass: returns the element-wise `input - target`.
    ///
    /// # Errors
    /// Propagates any error raised by the underlying tensor `sub`.
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // NOTE(review): this is NOT a cross-entropy loss — there is no
        // log-softmax / negative log-likelihood, and `self.reduction` is
        // ignored, so callers receive an unreduced difference tensor.
        // Replace with a real implementation once the required tensor ops
        // (presumably log, softmax/logsumexp, and index gather) are confirmed.
        let result = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    /// Returns `CrossEntropyLoss(reduction='<mode>')`.
    fn __repr__(&self) -> String {
        format!("CrossEntropyLoss(reduction='{}')", self.reduction)
    }

    /// Sets the training flag to `mode`, defaulting to `true` when omitted.
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Switches the module to evaluation mode (training flag = `false`).
    fn eval(&mut self) {
        self.training = false;
    }

    /// Returns the current training flag.
    fn training(&self) -> bool {
        self.training
    }
}
106
107#[pyclass(name = "BCELoss", extends = PyModule)]
109pub struct PyBCELoss {
110 reduction: String,
111 training: bool,
112}
113
#[pymethods]
impl PyBCELoss {
    /// Creates a `BCELoss` module.
    ///
    /// `reduction` defaults to `"mean"` when omitted; it is stored and shown
    /// by `__repr__`, but `forward` does not currently read it. The module
    /// starts in training mode.
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let reduction = reduction.unwrap_or_else(|| "mean".to_string());
        (
            Self {
                reduction,
                training: true,
            },
            PyModule::new(),
        )
    }

    /// Placeholder forward pass: returns the element-wise `input - target`.
    ///
    /// # Errors
    /// Propagates any error raised by the underlying tensor `sub`.
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // NOTE(review): this is NOT a binary cross-entropy — the
        // `-(t*log(x) + (1-t)*log(1-x))` terms are missing, and
        // `self.reduction` is ignored, so callers receive an unreduced
        // difference tensor. Replace once log / elementwise-mul ops on the
        // tensor type are confirmed.
        let result = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: result })
    }

    /// Returns `BCELoss(reduction='<mode>')`.
    fn __repr__(&self) -> String {
        format!("BCELoss(reduction='{}')", self.reduction)
    }

    /// Sets the training flag to `mode`, defaulting to `true` when omitted.
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Switches the module to evaluation mode (training flag = `false`).
    fn eval(&mut self) {
        self.training = false;
    }

    /// Returns the current training flag.
    fn training(&self) -> bool {
        self.training
    }
}