train_station/tensor/ops/
sigmoid.rs

1//! Sigmoid activation function
2//!
3//! Provides the sigmoid activation function following PyTorch conventions with
4//! comprehensive automatic differentiation support and numerically stable computation.
5//!
6//! # Key Features
7//!
8//! - **Sigmoid Activation**: `sigmoid()` - Computes 1/(1+e^(-x)) for each element (PyTorch `sigmoid()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
10//! - **Numerical Stability**: Avoids overflow for large positive/negative values
11//! - **Mathematical Accuracy**: High-precision sigmoid computation
12//! - **Range Guarantee**: Output values always in range (0, 1)
13//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
14//!
15//! # Mathematical Properties
16//!
17//! The sigmoid activation function has the following properties:
18//! - **Definition**: f(x) = 1 / (1 + e^(-x))
19//! - **Range**: (0, 1) - outputs are always between 0 and 1
20//! - **Symmetry**: f(-x) = 1 - f(x) for all x
21//! - **Monotonicity**: Strictly increasing function
22//! - **Continuity**: Continuous and differentiable everywhere
23//! - **Gradient**: f'(x) = f(x) * (1 - f(x)) = sigmoid(x) * (1 - sigmoid(x))
24//! - **Limits**: lim(x→-∞) f(x) = 0, lim(x→+∞) f(x) = 1
25//!
26//! # Performance Characteristics
27//!
28//! - **Numerical Stability**: Avoids overflow using stable implementation
29//! - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
30//! - **Cache-friendly Access**: Linear memory access patterns
31//! - **Mathematical Accuracy**: High-precision exponential and division operations
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37impl Tensor {
38    /// Element-wise sigmoid activation function
39    ///
40    /// Computes the sigmoid function for each element: `output[i] = 1 / (1 + e^(-self[i]))`
41    ///
42    /// Uses a numerically stable implementation that avoids overflow for large positive/negative
43    /// values by using different computation paths for positive and negative inputs.
44    ///
45    /// # Returns
46    ///
47    /// A new tensor with sigmoid applied to each element, values in range (0, 1)
48    ///
49    /// # Performance Characteristics
50    ///
51    /// - **Numerical Stability**: Avoids overflow using stable implementation
52    /// - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
53    /// - **Cache-friendly**: Linear memory access patterns
54    /// - **Mathematical Accuracy**: High-precision exponential and division operations
55    /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
56    ///
57    /// # Implementation Details
58    ///
59    /// Uses a numerically stable implementation:
60    /// - For x ≥ 0: computes 1 / (1 + e^(-x)) to avoid overflow in e^x for large positive x
61    /// - For x < 0: computes e^x / (1 + e^x) to avoid overflow in e^(-x) for large negative x
62    ///   This ensures the result is always in the range (0, 1) without numerical overflow.
63    ///
64    /// # Examples
65    ///
66    /// ## Basic Sigmoid Activation
67    ///
68    /// ```
69    /// use train_station::Tensor;
70    ///
71    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
72    /// let b = a.sigmoid();
73    /// assert_eq!(b.shape().dims, vec![3]);
74    /// assert!((b.get(&[0]) - 0.26894143).abs() < 1e-6); // sigmoid(-1.0)
75    /// assert!((b.get(&[1]) - 0.5).abs() < 1e-6); // sigmoid(0.0)
76    /// assert!((b.get(&[2]) - 0.7310586).abs() < 1e-6); // sigmoid(1.0)
77    /// ```
78    ///
79    /// ## Extreme Values
80    ///
81    /// ```
82    /// use train_station::Tensor;
83    ///
84    /// let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
85    /// let b = a.sigmoid();
86    /// assert_eq!(b.shape().dims, vec![2]);
87    /// assert!(b.get(&[0]) < 1e-4); // sigmoid(-10.0) ≈ 0
88    /// assert!(b.get(&[1]) > 0.9999); // sigmoid(10.0) ≈ 1
89    /// ```
90    pub fn sigmoid(&self) -> Tensor {
91        let mut out = Tensor::new(self.shape().dims.clone());
92        unsafe {
93            let src = self.as_ptr();
94            let dst = out.as_mut_ptr();
95            let n = self.size();
96            for i in 0..n {
97                let x = *src.add(i);
98                // Stable sigmoid
99                let y = if x >= 0.0 {
100                    let z = (-x).exp();
101                    1.0 / (1.0 + z)
102                } else {
103                    let z = x.exp();
104                    z / (1.0 + z)
105                };
106                *dst.add(i) = y;
107            }
108        }
109
110        if self.requires_grad() && is_grad_enabled() {
111            let mut result = out.clone();
112            result.set_requires_grad_internal(true);
113            let grad_fn = GradFn::Sigmoid {
114                saved_output: Box::new(out.clone()),
115            };
116            result.set_grad_fn(grad_fn.clone());
117            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
118            return result;
119        }
120
121        out
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass matches reference f32 sigmoid values at -1, 0, and 1.
    ///
    /// Uses the safe `get(&[i])` accessor (as in the doc examples) instead of
    /// raw pointer reads, so the test needs no `unsafe` and reports the
    /// failing index on mismatch.
    #[test]
    fn test_sigmoid_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
        let y = x.sigmoid();
        let expected = [0.26894143_f32, 0.5, 0.7310586];
        for (i, &want) in expected.iter().enumerate() {
            assert!(
                (y.get(&[i]) - want).abs() < 1e-6,
                "sigmoid mismatch at index {}",
                i
            );
        }
    }

    /// Large-magnitude inputs saturate toward 0/1 without overflowing,
    /// exercising both branches of the numerically stable implementation.
    #[test]
    fn test_sigmoid_saturation() {
        let x = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
        let y = x.sigmoid();
        assert!(y.get(&[0]) < 1e-4); // sigmoid(-10) ≈ 0
        assert!(y.get(&[1]) > 0.9999); // sigmoid(10) ≈ 1
    }
}