// train_station/tensor/ops/sigmoid.rs

//! Sigmoid activation function
//!
//! Provides the sigmoid activation function following PyTorch conventions with
//! comprehensive automatic differentiation support and numerically stable computation.
//!
//! # Key Features
//!
//! - **Sigmoid Activation**: `sigmoid()` - Computes 1/(1+e^(-x)) for each element (PyTorch `sigmoid()` equivalent)
//! - **Automatic Differentiation**: Full gradtrack support with efficient gradient computation
//! - **Numerical Stability**: Avoids overflow for large positive/negative values
//! - **Mathematical Accuracy**: High-precision sigmoid computation
//! - **Range Guarantee**: Output values always in range (0, 1)
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
//!
//! # Mathematical Properties
//!
//! The sigmoid activation function has the following properties:
//! - **Definition**: f(x) = 1 / (1 + e^(-x))
//! - **Range**: (0, 1) - outputs are always between 0 and 1
//! - **Symmetry**: f(-x) = 1 - f(x) for all x
//! - **Monotonicity**: Strictly increasing function
//! - **Continuity**: Continuous and differentiable everywhere
//! - **Gradient**: f'(x) = f(x) * (1 - f(x)) = sigmoid(x) * (1 - sigmoid(x))
//! - **Limits**: lim(x→-∞) f(x) = 0, lim(x→+∞) f(x) = 1
//!
//! # Performance Characteristics
//!
//! - **Numerical Stability**: Avoids overflow using stable implementation
//! - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
//! - **Cache-friendly Access**: Linear memory access patterns
//! - **Mathematical Accuracy**: High-precision exponential and division operations
//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;

37impl Tensor {
38    /// Element-wise sigmoid activation function
39    ///
40    /// Computes the sigmoid function for each element: `output[i] = 1 / (1 + e^(-self[i]))`
41    ///
42    /// Uses a numerically stable implementation that avoids overflow for large positive/negative
43    /// values by using different computation paths for positive and negative inputs.
44    ///
45    /// # Returns
46    ///
47    /// A new tensor with sigmoid applied to each element, values in range (0, 1)
48    ///
49    /// # Performance Characteristics
50    ///
51    /// - **Numerical Stability**: Avoids overflow using stable implementation
52    /// - **Scalar Implementation**: Optimized scalar computation for mathematical accuracy
53    /// - **Cache-friendly**: Linear memory access patterns
54    /// - **Mathematical Accuracy**: High-precision exponential and division operations
55    /// - **Gradient Tracking**: Full gradtrack support with efficient gradient computation
56    ///
57    /// # Implementation Details
58    ///
59    /// Uses a numerically stable implementation:
60    /// - For x ≥ 0: computes 1 / (1 + e^(-x)) to avoid overflow in e^x for large positive x
61    /// - For x < 0: computes e^x / (1 + e^x) to avoid overflow in e^(-x) for large negative x
62    ///   This ensures the result is always in the range (0, 1) without numerical overflow.
63    ///
64    /// # Examples
65    ///
66    /// ## Basic Sigmoid Activation
67    ///
68    /// ```
69    /// use train_station::Tensor;
70    ///
71    /// let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
72    /// let b = a.sigmoid();
73    /// assert_eq!(b.shape().dims, vec![3]);
74    /// assert!((b.get(&[0]) - 0.26894143).abs() < 1e-6); // sigmoid(-1.0)
75    /// assert!((b.get(&[1]) - 0.5).abs() < 1e-6); // sigmoid(0.0)
76    /// assert!((b.get(&[2]) - 0.7310586).abs() < 1e-6); // sigmoid(1.0)
77    /// ```
78    ///
79    /// ## Extreme Values
80    ///
81    /// ```
82    /// use train_station::Tensor;
83    ///
84    /// let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
85    /// let b = a.sigmoid();
86    /// assert_eq!(b.shape().dims, vec![2]);
87    /// assert!(b.get(&[0]) < 1e-4); // sigmoid(-10.0) ≈ 0
88    /// assert!(b.get(&[1]) > 0.9999); // sigmoid(10.0) ≈ 1
89    /// ```
90    #[track_caller]
91    pub fn sigmoid(&self) -> Tensor {
92        let mut out = Tensor::new(self.shape().dims.clone());
93        unsafe {
94            let src = self.as_ptr();
95            let dst = out.as_mut_ptr();
96            let n = self.size();
97            for i in 0..n {
98                let x = *src.add(i);
99                // Stable sigmoid
100                let y = if x >= 0.0 {
101                    let z = (-x).exp();
102                    1.0 / (1.0 + z)
103                } else {
104                    let z = x.exp();
105                    z / (1.0 + z)
106                };
107                *dst.add(i) = y;
108            }
109        }
110
111        if self.requires_grad() && is_grad_enabled() {
112            let mut result = out.clone();
113            result.set_requires_grad_internal(true);
114            let grad_fn = GradFn::Sigmoid {
115                saved_output: Box::new(out.clone()),
116            };
117            result.set_grad_fn(grad_fn.clone());
118            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
119            return result;
120        }
121
122        out
123    }
124}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass produces the known reference sigmoid values for
    /// representative negative, zero, and positive inputs.
    #[test]
    fn test_sigmoid_forward_basic() {
        let x = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
        let y = x.sigmoid();
        // Use the safe public accessor rather than raw pointer reads.
        assert!((y.get(&[0]) - 0.26894143).abs() < 1e-6); // sigmoid(-1.0)
        assert!((y.get(&[1]) - 0.5).abs() < 1e-6); // sigmoid(0.0)
        assert!((y.get(&[2]) - 0.7310586).abs() < 1e-6); // sigmoid(1.0)
    }
}