// torsh_ffi/python/functional.rs

1//! Functional operations for PyTorch-style API
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5use crate::error::FfiError;
6use crate::python::tensor::PyTensor;
7use pyo3::prelude::*;
8
9/// ReLU activation function
10#[pyfunction]
11#[pyo3(signature = (input, inplace=false))]
12pub fn relu(input: &PyTensor, inplace: bool) -> PyResult<PyTensor> {
13    let result_data: Vec<f32> = input.data.iter().map(|&x| x.max(0.0)).collect();
14
15    if inplace {
16        // In a real implementation, this would modify the input tensor in-place
17        // For now, we'll return a new tensor
18    }
19
20    Python::attach(|py| {
21        let data = pyo3::types::PyList::new(py, &result_data)?;
22        PyTensor::new(
23            data.as_ref(),
24            Some(input.shape()),
25            Some("f32"),
26            input.requires_grad,
27        )
28    })
29}
30
31/// Sigmoid activation function
32#[pyfunction]
33pub fn sigmoid(input: &PyTensor) -> PyResult<PyTensor> {
34    let result_data: Vec<f32> = input
35        .data
36        .iter()
37        .map(|&x| 1.0 / (1.0 + (-x).exp()))
38        .collect();
39
40    Python::attach(|py| {
41        let data = pyo3::types::PyList::new(py, &result_data)?;
42        PyTensor::new(
43            data.as_ref(),
44            Some(input.shape()),
45            Some("f32"),
46            input.requires_grad,
47        )
48    })
49}
50
51/// Tanh activation function
52#[pyfunction]
53pub fn tanh(input: &PyTensor) -> PyResult<PyTensor> {
54    let result_data: Vec<f32> = input.data.iter().map(|&x| x.tanh()).collect();
55
56    Python::attach(|py| {
57        let data = pyo3::types::PyList::new(py, &result_data)?;
58        PyTensor::new(
59            data.as_ref(),
60            Some(input.shape()),
61            Some("f32"),
62            input.requires_grad,
63        )
64    })
65}
66
67/// GELU activation function (Gaussian Error Linear Unit)
68#[pyfunction]
69pub fn gelu(input: &PyTensor) -> PyResult<PyTensor> {
70    let result_data: Vec<f32> = input
71        .data
72        .iter()
73        .map(|&x| {
74            // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
75            let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
76            let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
77            0.5 * x * (1.0 + inner.tanh())
78        })
79        .collect();
80
81    Python::attach(|py| {
82        let data = pyo3::types::PyList::new(py, &result_data)?;
83        PyTensor::new(
84            data.as_ref(),
85            Some(input.shape()),
86            Some("f32"),
87            input.requires_grad,
88        )
89    })
90}
91
92/// Softmax function
93#[pyfunction]
94#[pyo3(signature = (input, _dim=-1))]
95pub fn softmax(input: &PyTensor, _dim: i32) -> PyResult<PyTensor> {
96    if input.shape().len() != 2 {
97        return Err(FfiError::UnsupportedOperation {
98            operation: "Softmax currently only supports 2D tensors".to_string(),
99        }
100        .into());
101    }
102
103    let batch_size = input.shape()[0];
104    let features = input.shape()[1];
105    let mut result_data = vec![0.0; input.data.len()];
106
107    // Apply softmax along the last dimension (dim=-1)
108    for batch_idx in 0..batch_size {
109        let start_idx = batch_idx * features;
110        let end_idx = start_idx + features;
111        let batch_slice = &input.data[start_idx..end_idx];
112
113        // Find max for numerical stability
114        let max_val = batch_slice.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
115
116        // Compute exponentials and sum
117        let mut sum = 0.0;
118        for i in 0..features {
119            let exp_val = (batch_slice[i] - max_val).exp();
120            result_data[start_idx + i] = exp_val;
121            sum += exp_val;
122        }
123
124        // Normalize
125        for i in 0..features {
126            result_data[start_idx + i] /= sum;
127        }
128    }
129
130    Python::attach(|py| {
131        let data = pyo3::types::PyList::new(py, &result_data)?;
132        PyTensor::new(
133            data.as_ref(),
134            Some(input.shape()),
135            Some("f32"),
136            input.requires_grad,
137        )
138    })
139}
140
141/// Log softmax function
142#[pyfunction]
143#[pyo3(signature = (input, dim=-1))]
144pub fn log_softmax(input: &PyTensor, dim: i32) -> PyResult<PyTensor> {
145    let softmax_result = softmax(input, dim)?;
146
147    let result_data: Vec<f32> = softmax_result.data.iter().map(|&x| x.ln()).collect();
148
149    Python::attach(|py| {
150        let data = pyo3::types::PyList::new(py, &result_data)?;
151        PyTensor::new(
152            data.as_ref(),
153            Some(input.shape()),
154            Some("f32"),
155            input.requires_grad,
156        )
157    })
158}
159
160/// Cross entropy loss
161#[pyfunction]
162#[pyo3(signature = (input, target, reduction="mean"))]
163pub fn cross_entropy(input: &PyTensor, target: &PyTensor, reduction: &str) -> PyResult<PyTensor> {
164    if input.shape().len() != 2 || target.shape().len() != 1 {
165        return Err(FfiError::ShapeMismatch {
166            expected: vec![0, 0], // Placeholder
167            actual: vec![input.shape().len(), target.shape().len()],
168        }
169        .into());
170    }
171
172    let batch_size = input.shape()[0];
173    let num_classes = input.shape()[1];
174
175    if target.shape()[0] != batch_size {
176        return Err(FfiError::ShapeMismatch {
177            expected: vec![batch_size],
178            actual: target.shape(),
179        }
180        .into());
181    }
182
183    // Apply log_softmax to input
184    let log_probs = log_softmax(input, -1)?;
185
186    let mut losses = Vec::new();
187
188    // Compute negative log likelihood for each sample
189    for batch_idx in 0..batch_size {
190        let target_class = target.data[batch_idx] as usize;
191        if target_class >= num_classes {
192            return Err(FfiError::InvalidParameter {
193                parameter: "target".to_string(),
194                value: format!("class {} >= num_classes {}", target_class, num_classes),
195            }
196            .into());
197        }
198
199        let log_prob = log_probs.data[batch_idx * num_classes + target_class];
200        losses.push(-log_prob);
201    }
202
203    let result = match reduction {
204        "mean" => {
205            let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
206            vec![mean_loss]
207        }
208        "sum" => {
209            let sum_loss = losses.iter().sum::<f32>();
210            vec![sum_loss]
211        }
212        "none" => losses,
213        _ => {
214            return Err(FfiError::InvalidParameter {
215                parameter: "reduction".to_string(),
216                value: reduction.to_string(),
217            }
218            .into())
219        }
220    };
221
222    Python::attach(|py| {
223        let data = pyo3::types::PyList::new(py, &result)?;
224        let shape = if reduction == "none" {
225            vec![batch_size]
226        } else {
227            vec![] // Scalar
228        };
229        PyTensor::new(
230            data.as_ref(),
231            Some(shape),
232            Some("f32"),
233            input.requires_grad || target.requires_grad,
234        )
235    })
236}
237
238/// Mean squared error loss
239#[pyfunction]
240#[pyo3(signature = (input, target, reduction="mean"))]
241pub fn mse_loss(input: &PyTensor, target: &PyTensor, reduction: &str) -> PyResult<PyTensor> {
242    if input.shape() != target.shape() {
243        return Err(FfiError::ShapeMismatch {
244            expected: input.shape(),
245            actual: target.shape(),
246        }
247        .into());
248    }
249
250    let squared_errors: Vec<f32> = input
251        .data
252        .iter()
253        .zip(target.data.iter())
254        .map(|(&x, &y)| (x - y).powi(2))
255        .collect();
256
257    let result = match reduction {
258        "mean" => {
259            let mean_loss = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
260            vec![mean_loss]
261        }
262        "sum" => {
263            let sum_loss = squared_errors.iter().sum::<f32>();
264            vec![sum_loss]
265        }
266        "none" => squared_errors,
267        _ => {
268            return Err(FfiError::InvalidParameter {
269                parameter: "reduction".to_string(),
270                value: reduction.to_string(),
271            }
272            .into())
273        }
274    };
275
276    Python::attach(|py| {
277        let data = pyo3::types::PyList::new(py, &result)?;
278        let shape = if reduction == "none" {
279            input.shape()
280        } else {
281            vec![] // Scalar
282        };
283        PyTensor::new(
284            data.as_ref(),
285            Some(shape),
286            Some("f32"),
287            input.requires_grad || target.requires_grad,
288        )
289    })
290}
291
292/// Binary cross entropy loss
293#[pyfunction]
294#[pyo3(signature = (input, target, _weight=None, reduction="mean"))]
295pub fn binary_cross_entropy(
296    input: &PyTensor,
297    target: &PyTensor,
298    _weight: Option<&PyTensor>,
299    reduction: &str,
300) -> PyResult<PyTensor> {
301    if input.shape() != target.shape() {
302        return Err(FfiError::ShapeMismatch {
303            expected: input.shape(),
304            actual: target.shape(),
305        }
306        .into());
307    }
308
309    let losses: Vec<f32> = input
310        .data
311        .iter()
312        .zip(target.data.iter())
313        .map(|(&pred, &target)| {
314            // BCE loss: -[target * log(pred) + (1 - target) * log(1 - pred)]
315            let pred_clamped = pred.clamp(1e-7, 1.0 - 1e-7); // Numerical stability
316            -(target * pred_clamped.ln() + (1.0 - target) * (1.0 - pred_clamped).ln())
317        })
318        .collect();
319
320    let result = match reduction {
321        "mean" => {
322            let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
323            vec![mean_loss]
324        }
325        "sum" => {
326            let sum_loss = losses.iter().sum::<f32>();
327            vec![sum_loss]
328        }
329        "none" => losses,
330        _ => {
331            return Err(FfiError::InvalidParameter {
332                parameter: "reduction".to_string(),
333                value: reduction.to_string(),
334            }
335            .into())
336        }
337    };
338
339    Python::attach(|py| {
340        let data = pyo3::types::PyList::new(py, &result)?;
341        let shape = if reduction == "none" {
342            input.shape()
343        } else {
344            vec![] // Scalar
345        };
346        PyTensor::new(
347            data.as_ref(),
348            Some(shape),
349            Some("f32"),
350            input.requires_grad || target.requires_grad,
351        )
352    })
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::types::PyList;
    use pyo3::Python;

    #[test]
    fn test_relu() {
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![-1.0, 0.0, 1.0, 2.0])?;
            let tensor = PyTensor::new(values.as_ref(), None, None, false).unwrap();

            // Negatives clamp to zero; non-negatives pass through unchanged.
            let rectified = relu(&tensor, false).unwrap();
            assert_eq!(rectified.data, vec![0.0, 0.0, 1.0, 2.0]);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_sigmoid() {
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![0.0])?;
            let tensor = PyTensor::new(values.as_ref(), None, None, false).unwrap();

            // sigmoid(0) is exactly 0.5 (up to float rounding).
            let squashed = sigmoid(&tensor).unwrap();
            assert!((squashed.data[0] - 0.5).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_softmax() {
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let values = PyList::new(py, vec![1.0, 2.0, 3.0])?;
            let tensor = PyTensor::new(values.as_ref(), Some(vec![1, 3]), None, false).unwrap();

            // A softmax row must sum to one.
            let probs = softmax(&tensor, -1).unwrap();
            let total: f32 = probs.data.iter().sum();
            assert!((total - 1.0).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_mse_loss() {
        Python::initialize();
        Python::attach(|py| -> PyResult<()> {
            let predictions = PyList::new(py, vec![1.0, 2.0, 3.0])?;
            let truths = PyList::new(py, vec![1.5, 2.5, 3.5])?;

            let pred_tensor = PyTensor::new(predictions.as_ref(), None, None, false).unwrap();
            let truth_tensor = PyTensor::new(truths.as_ref(), None, None, false).unwrap();

            // Each squared error is 0.25, so the mean is 0.25.
            let loss = mse_loss(&pred_tensor, &truth_tensor, "mean").unwrap();
            assert!((loss.data[0] - 0.25).abs() < 1e-6);
            Ok(())
        })
        .unwrap();
    }
}
421}