1#![allow(dead_code)]
5use crate::error::FfiError;
6use crate::python::tensor::PyTensor;
7use pyo3::prelude::*;
8
9#[pyfunction]
11#[pyo3(signature = (input, inplace=false))]
12pub fn relu(input: &PyTensor, inplace: bool) -> PyResult<PyTensor> {
13 let result_data: Vec<f32> = input.data.iter().map(|&x| x.max(0.0)).collect();
14
15 if inplace {
16 }
19
20 Python::attach(|py| {
21 let data = pyo3::types::PyList::new(py, &result_data)?;
22 PyTensor::new(
23 data.as_ref(),
24 Some(input.shape()),
25 Some("f32"),
26 input.requires_grad,
27 )
28 })
29}
30
31#[pyfunction]
33pub fn sigmoid(input: &PyTensor) -> PyResult<PyTensor> {
34 let result_data: Vec<f32> = input
35 .data
36 .iter()
37 .map(|&x| 1.0 / (1.0 + (-x).exp()))
38 .collect();
39
40 Python::attach(|py| {
41 let data = pyo3::types::PyList::new(py, &result_data)?;
42 PyTensor::new(
43 data.as_ref(),
44 Some(input.shape()),
45 Some("f32"),
46 input.requires_grad,
47 )
48 })
49}
50
51#[pyfunction]
53pub fn tanh(input: &PyTensor) -> PyResult<PyTensor> {
54 let result_data: Vec<f32> = input.data.iter().map(|&x| x.tanh()).collect();
55
56 Python::attach(|py| {
57 let data = pyo3::types::PyList::new(py, &result_data)?;
58 PyTensor::new(
59 data.as_ref(),
60 Some(input.shape()),
61 Some("f32"),
62 input.requires_grad,
63 )
64 })
65}
66
67#[pyfunction]
69pub fn gelu(input: &PyTensor) -> PyResult<PyTensor> {
70 let result_data: Vec<f32> = input
71 .data
72 .iter()
73 .map(|&x| {
74 let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
76 let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
77 0.5 * x * (1.0 + inner.tanh())
78 })
79 .collect();
80
81 Python::attach(|py| {
82 let data = pyo3::types::PyList::new(py, &result_data)?;
83 PyTensor::new(
84 data.as_ref(),
85 Some(input.shape()),
86 Some("f32"),
87 input.requires_grad,
88 )
89 })
90}
91
92#[pyfunction]
94#[pyo3(signature = (input, _dim=-1))]
95pub fn softmax(input: &PyTensor, _dim: i32) -> PyResult<PyTensor> {
96 if input.shape().len() != 2 {
97 return Err(FfiError::UnsupportedOperation {
98 operation: "Softmax currently only supports 2D tensors".to_string(),
99 }
100 .into());
101 }
102
103 let batch_size = input.shape()[0];
104 let features = input.shape()[1];
105 let mut result_data = vec![0.0; input.data.len()];
106
107 for batch_idx in 0..batch_size {
109 let start_idx = batch_idx * features;
110 let end_idx = start_idx + features;
111 let batch_slice = &input.data[start_idx..end_idx];
112
113 let max_val = batch_slice.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
115
116 let mut sum = 0.0;
118 for i in 0..features {
119 let exp_val = (batch_slice[i] - max_val).exp();
120 result_data[start_idx + i] = exp_val;
121 sum += exp_val;
122 }
123
124 for i in 0..features {
126 result_data[start_idx + i] /= sum;
127 }
128 }
129
130 Python::attach(|py| {
131 let data = pyo3::types::PyList::new(py, &result_data)?;
132 PyTensor::new(
133 data.as_ref(),
134 Some(input.shape()),
135 Some("f32"),
136 input.requires_grad,
137 )
138 })
139}
140
141#[pyfunction]
143#[pyo3(signature = (input, dim=-1))]
144pub fn log_softmax(input: &PyTensor, dim: i32) -> PyResult<PyTensor> {
145 let softmax_result = softmax(input, dim)?;
146
147 let result_data: Vec<f32> = softmax_result.data.iter().map(|&x| x.ln()).collect();
148
149 Python::attach(|py| {
150 let data = pyo3::types::PyList::new(py, &result_data)?;
151 PyTensor::new(
152 data.as_ref(),
153 Some(input.shape()),
154 Some("f32"),
155 input.requires_grad,
156 )
157 })
158}
159
160#[pyfunction]
162#[pyo3(signature = (input, target, reduction="mean"))]
163pub fn cross_entropy(input: &PyTensor, target: &PyTensor, reduction: &str) -> PyResult<PyTensor> {
164 if input.shape().len() != 2 || target.shape().len() != 1 {
165 return Err(FfiError::ShapeMismatch {
166 expected: vec![0, 0], actual: vec![input.shape().len(), target.shape().len()],
168 }
169 .into());
170 }
171
172 let batch_size = input.shape()[0];
173 let num_classes = input.shape()[1];
174
175 if target.shape()[0] != batch_size {
176 return Err(FfiError::ShapeMismatch {
177 expected: vec![batch_size],
178 actual: target.shape(),
179 }
180 .into());
181 }
182
183 let log_probs = log_softmax(input, -1)?;
185
186 let mut losses = Vec::new();
187
188 for batch_idx in 0..batch_size {
190 let target_class = target.data[batch_idx] as usize;
191 if target_class >= num_classes {
192 return Err(FfiError::InvalidParameter {
193 parameter: "target".to_string(),
194 value: format!("class {} >= num_classes {}", target_class, num_classes),
195 }
196 .into());
197 }
198
199 let log_prob = log_probs.data[batch_idx * num_classes + target_class];
200 losses.push(-log_prob);
201 }
202
203 let result = match reduction {
204 "mean" => {
205 let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
206 vec![mean_loss]
207 }
208 "sum" => {
209 let sum_loss = losses.iter().sum::<f32>();
210 vec![sum_loss]
211 }
212 "none" => losses,
213 _ => {
214 return Err(FfiError::InvalidParameter {
215 parameter: "reduction".to_string(),
216 value: reduction.to_string(),
217 }
218 .into())
219 }
220 };
221
222 Python::attach(|py| {
223 let data = pyo3::types::PyList::new(py, &result)?;
224 let shape = if reduction == "none" {
225 vec![batch_size]
226 } else {
227 vec![] };
229 PyTensor::new(
230 data.as_ref(),
231 Some(shape),
232 Some("f32"),
233 input.requires_grad || target.requires_grad,
234 )
235 })
236}
237
238#[pyfunction]
240#[pyo3(signature = (input, target, reduction="mean"))]
241pub fn mse_loss(input: &PyTensor, target: &PyTensor, reduction: &str) -> PyResult<PyTensor> {
242 if input.shape() != target.shape() {
243 return Err(FfiError::ShapeMismatch {
244 expected: input.shape(),
245 actual: target.shape(),
246 }
247 .into());
248 }
249
250 let squared_errors: Vec<f32> = input
251 .data
252 .iter()
253 .zip(target.data.iter())
254 .map(|(&x, &y)| (x - y).powi(2))
255 .collect();
256
257 let result = match reduction {
258 "mean" => {
259 let mean_loss = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
260 vec![mean_loss]
261 }
262 "sum" => {
263 let sum_loss = squared_errors.iter().sum::<f32>();
264 vec![sum_loss]
265 }
266 "none" => squared_errors,
267 _ => {
268 return Err(FfiError::InvalidParameter {
269 parameter: "reduction".to_string(),
270 value: reduction.to_string(),
271 }
272 .into())
273 }
274 };
275
276 Python::attach(|py| {
277 let data = pyo3::types::PyList::new(py, &result)?;
278 let shape = if reduction == "none" {
279 input.shape()
280 } else {
281 vec![] };
283 PyTensor::new(
284 data.as_ref(),
285 Some(shape),
286 Some("f32"),
287 input.requires_grad || target.requires_grad,
288 )
289 })
290}
291
292#[pyfunction]
294#[pyo3(signature = (input, target, _weight=None, reduction="mean"))]
295pub fn binary_cross_entropy(
296 input: &PyTensor,
297 target: &PyTensor,
298 _weight: Option<&PyTensor>,
299 reduction: &str,
300) -> PyResult<PyTensor> {
301 if input.shape() != target.shape() {
302 return Err(FfiError::ShapeMismatch {
303 expected: input.shape(),
304 actual: target.shape(),
305 }
306 .into());
307 }
308
309 let losses: Vec<f32> = input
310 .data
311 .iter()
312 .zip(target.data.iter())
313 .map(|(&pred, &target)| {
314 let pred_clamped = pred.clamp(1e-7, 1.0 - 1e-7); -(target * pred_clamped.ln() + (1.0 - target) * (1.0 - pred_clamped).ln())
317 })
318 .collect();
319
320 let result = match reduction {
321 "mean" => {
322 let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
323 vec![mean_loss]
324 }
325 "sum" => {
326 let sum_loss = losses.iter().sum::<f32>();
327 vec![sum_loss]
328 }
329 "none" => losses,
330 _ => {
331 return Err(FfiError::InvalidParameter {
332 parameter: "reduction".to_string(),
333 value: reduction.to_string(),
334 }
335 .into())
336 }
337 };
338
339 Python::attach(|py| {
340 let data = pyo3::types::PyList::new(py, &result)?;
341 let shape = if reduction == "none" {
342 input.shape()
343 } else {
344 vec![] };
346 PyTensor::new(
347 data.as_ref(),
348 Some(shape),
349 Some("f32"),
350 input.requires_grad || target.requires_grad,
351 )
352 })
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::types::PyList;
    use pyo3::Python;

    #[test]
    fn test_relu() {
        // Negative inputs clamp to zero; non-negative values pass through.
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![-1.0, 0.0, 1.0, 2.0])?;
            let tensor = PyTensor::new(values.as_ref(), None, None, false).unwrap();

            let activated = relu(&tensor, false).unwrap();
            assert_eq!(activated.data, vec![0.0, 0.0, 1.0, 2.0]);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_sigmoid() {
        // sigmoid(0) should be 0.5 within float tolerance.
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![0.0])?;
            let tensor = PyTensor::new(values.as_ref(), None, None, false).unwrap();

            let activated = sigmoid(&tensor).unwrap();
            assert!((activated.data[0] - 0.5).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_softmax() {
        // Each softmax row must sum to 1.
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![1.0, 2.0, 3.0])?;
            let tensor = PyTensor::new(values.as_ref(), Some(vec![1, 3]), None, false).unwrap();

            let probs = softmax(&tensor, -1).unwrap();
            let total: f32 = probs.data.iter().sum();
            assert!((total - 1.0).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_mse_loss() {
        // Every element is off by 0.5, so the mean squared error is 0.25.
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let predictions = PyList::new(py, vec![1.0, 2.0, 3.0])?;
            let labels = PyList::new(py, vec![1.5, 2.5, 3.5])?;

            let pred_tensor = PyTensor::new(predictions.as_ref(), None, None, false).unwrap();
            let label_tensor = PyTensor::new(labels.as_ref(), None, None, false).unwrap();

            let loss = mse_loss(&pred_tensor, &label_tensor, "mean").unwrap();
            assert!((loss.data[0] - 0.25).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }
}