Skip to main content

svod_tensor/
activation.rs

//! Activation functions for neural networks.
//!
//! This module provides common activation functions used in deep learning,
//! including relu, sigmoid, tanh, softmax, and their variants, as well as an
//! inference-mode batch normalization helper.
5
6use bon::bon;
7use snafu::ResultExt;
8use svod_ir::{ConstValue, UOp};
9
10use crate::reduce::AxisSpec;
11use crate::{Result, Tensor, error::UOpSnafu};
12
13#[bon]
14impl Tensor {
15    /// Rectified Linear Unit: `max(0, x)`.
16    ///
17    /// ReLU is one of the most common activation functions in deep learning.
18    /// It's simple, efficient, and helps mitigate the vanishing gradient problem.
19    ///
20    /// # Examples
21    /// ```ignore
22    /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
23    /// let y = x.relu()?;
24    /// // y = [0.0, 0.0, 0.0, 1.0, 2.0]
25    /// ```
26    pub fn relu(&self) -> Result<Self> {
27        let zero = self.zero()?;
28        let condition = self.try_gt(&zero)?;
29        self.where_(&condition, &zero)
30    }
31
32    /// Sigmoid activation: `1 / (1 + exp(-x))`.
33    ///
34    /// Maps input to range (0, 1), commonly used for binary classification.
35    ///
36    /// # Examples
37    /// ```ignore
38    /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
39    /// let y = x.sigmoid()?;
40    /// // y ≈ [0.119, 0.268, 0.5, 0.731, 0.880]
41    /// ```
42    pub fn sigmoid(&self) -> Result<Self> {
43        // sigmoid(x) = 1 / (1 + exp(-x)) = 1 / (1 + 2^(-x/ln2))
44        // Using exp2 matches Tinygrad's implementation for better hardware mapping.
45        let scale = self.broadcast_scalar(ConstValue::Float(-1.0 / std::f64::consts::LN_2))?;
46        let scaled = self.try_mul(&scale)?;
47        let exp2_val = scaled.try_exp2()?;
48        let one = exp2_val.one()?;
49        let denominator = one.try_add(&exp2_val)?;
50        let recip = Self::new(UOp::try_reciprocal(&denominator.uop()).context(UOpSnafu)?);
51        Ok(recip)
52    }
53
54    /// Hyperbolic tangent: `tanh(x)`.
55    ///
56    /// Maps input to range (-1, 1), centered at zero.
57    ///
58    /// # Examples
59    /// ```ignore
60    /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
61    /// let y = x.tanh()?;
62    /// // y ≈ [-0.964, -0.762, 0.0, 0.762, 0.964]
63    /// ```
64    pub fn tanh(&self) -> Result<Self> {
65        // Check if tanh is a UOp primitive
66        // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
67        // Or: tanh(x) = 2*sigmoid(2x) - 1
68        let two = self.broadcast_scalar(ConstValue::Int(2))?;
69        let two_x = self.try_mul(&two)?;
70        let sig = two_x.sigmoid()?;
71        let two_sig = two.try_mul(&sig)?;
72        let one = sig.one()?;
73        two_sig.try_sub(&one)
74    }
75
76    /// Softmax activation: `exp(x - max(x)) / sum(exp(x - max(x)))`.
77    ///
78    /// Converts logits to probability distribution over specified axis.
79    /// Numerically stable implementation using max subtraction.
80    ///
81    /// # Arguments
82    /// * `axis` - Axis along which to compute softmax (default: -1, last axis)
83    ///
84    /// # Examples
85    /// ```ignore
86    /// let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
87    /// let probs = logits.softmax(-1)?;
88    /// // sum(probs) = 1.0, probs[i] > 0 for all i
89    /// ```
90    pub fn softmax(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
91        let axis = axis.into();
92
93        // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
94        // keepdim preserves the reduced axis as size 1 for correct broadcasting
95        let max_val = self.max_with().axes(axis.clone()).keepdim(true).call()?;
96        let shifted = self.try_sub(&max_val)?;
97        let exp_shifted = shifted.try_exp()?;
98        let sum_exp = exp_shifted.sum_with().axes(axis).keepdim(true).call()?;
99
100        exp_shifted.try_div(&sum_exp)
101    }
102
103    /// Log-softmax activation: `log(softmax(x))`.
104    ///
105    /// Numerically stable implementation: `x - max(x) - log(sum(exp(x - max(x))))`.
106    ///
107    /// More numerically stable than computing `log(softmax(x))` separately.
108    ///
109    /// # Arguments
110    /// * `axis` - Axis along which to compute log-softmax (default: -1, last axis)
111    ///
112    /// # Examples
113    /// ```ignore
114    /// let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
115    /// let log_probs = logits.log_softmax(-1)?;
116    /// // More numerically stable than logits.softmax(-1)?.try_log()
117    /// ```
118    pub fn log_softmax(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
119        let axis = axis.into();
120
121        // log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))
122        // keepdim preserves the reduced axis as size 1 for correct broadcasting
123        let max_val = self.max_with().axes(axis.clone()).keepdim(true).call()?;
124        let shifted = self.try_sub(&max_val)?;
125        let exp_shifted = shifted.try_exp()?;
126        let sum_exp = exp_shifted.sum_with().axes(axis).keepdim(true).call()?;
127        let log_sum_exp = sum_exp.try_log()?;
128        shifted.try_sub(&log_sum_exp)
129    }
130
131    /// Log-sum-exp: `log(sum(exp(x)))`.
132    ///
133    /// Numerically stable implementation: `max(x) + log(sum(exp(x - max(x))))`.
134    ///
135    /// # Arguments
136    /// * `axis` - Axis along which to compute logsumexp
137    ///
138    /// # Examples
139    /// ```ignore
140    /// let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
141    /// let lse = x.logsumexp(-1)?;
142    /// ```
143    pub fn logsumexp(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
144        let axis = axis.into();
145
146        // logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
147        // Use keepdim internally for correct broadcasting, then drop via max
148        let max_keepdim = self.max_with().axes(axis.clone()).keepdim(true).call()?;
149        let shifted = self.try_sub(&max_keepdim)?;
150        let exp_shifted = shifted.try_exp()?;
151        let sum_exp = exp_shifted.sum_with().axes(axis.clone()).keepdim(true).call()?;
152        let log_sum = sum_exp.try_log()?;
153        let result_keepdim = max_keepdim.try_add(&log_sum)?;
154
155        // Drop the keepdim axis — max over size-1 dim is effectively a squeeze
156        result_keepdim.max(axis)
157    }
158
159    /// GELU activation (Gaussian Error Linear Unit).
160    ///
161    /// Smooth approximation: `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`.
162    ///
163    /// GELU is the standard activation for Transformer models (BERT, GPT, etc.).
164    ///
165    /// # Examples
166    /// ```ignore
167    /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
168    /// let y = x.gelu()?;
169    /// ```
170    pub fn gelu(&self) -> Result<Self> {
171        // gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
172        // sqrt(2/π) ≈ 0.7978845608
173
174        let half = self.broadcast_scalar(ConstValue::Float(0.5))?;
175        let one = self.broadcast_scalar(ConstValue::Float(1.0))?;
176        let coef1 = self.broadcast_scalar(ConstValue::Float(0.7978845608))?;
177        let coef2 = self.broadcast_scalar(ConstValue::Float(0.044715))?;
178
179        // x^3
180        let x_squared = self.try_mul(self)?;
181        let x_cubed = x_squared.try_mul(self)?;
182
183        // 0.044715 * x^3
184        let cubic_term = coef2.try_mul(&x_cubed)?;
185
186        // x + 0.044715 * x^3
187        let inner = self.try_add(&cubic_term)?;
188
189        // sqrt(2/π) * (x + 0.044715 * x^3)
190        let scaled = coef1.try_mul(&inner)?;
191
192        // tanh(...)
193        let tanh_part = scaled.tanh()?;
194
195        // 1 + tanh(...)
196        let one_plus_tanh = one.try_add(&tanh_part)?;
197
198        // x * (1 + tanh(...))
199        let x_times = self.try_mul(&one_plus_tanh)?;
200
201        // 0.5 * x * (1 + tanh(...))
202        half.try_mul(&x_times)
203    }
204
205    /// Exact GELU: `0.5 * x * (1 + erf(x / sqrt(2)))`.
206    pub fn gelu_exact(&self) -> Result<Self> {
207        let dtype = self.uop().dtype();
208        let half = Tensor::const_(0.5f64, dtype.clone());
209        let one = Tensor::const_(1.0f64, dtype.clone());
210        let sqrt2 = Tensor::const_(std::f64::consts::SQRT_2, dtype);
211        half.try_mul(self)?.try_mul(&one.try_add(&self.try_div(&sqrt2)?.erf()?)?)
212    }
213
214    /// Hard Sigmoid: `clamp(alpha * x + beta, 0, 1)`.
215    ///
216    /// Piecewise linear approximation of sigmoid. Faster to compute.
217    ///
218    /// # Arguments
219    /// * `alpha` - Slope (default 0.2 in ONNX)
220    /// * `beta` - Offset (default 0.5 in ONNX)
221    pub fn hard_sigmoid(&self, alpha: f64, beta: f64) -> Result<Self> {
222        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
223        let beta_t = self.broadcast_scalar(ConstValue::Float(beta))?;
224        let zero = self.broadcast_scalar(ConstValue::Float(0.0))?;
225        let one = self.broadcast_scalar(ConstValue::Float(1.0))?;
226        let ax = alpha_t.try_mul(self)?;
227        let axb = ax.try_add(&beta_t)?;
228        // clamp(axb, 0, 1) = max(0, min(1, axb))
229        let clamped_low = axb.maximum(&zero)?;
230        clamped_low.minimum(&one)
231    }
232
233    /// Leaky ReLU: `x if x > 0, alpha * x otherwise`.
234    ///
235    /// # Arguments
236    /// * `alpha` - Negative slope (default 0.01 in ONNX)
237    pub fn leaky_relu(&self, alpha: f64) -> Result<Self> {
238        let zero = self.zero()?;
239        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
240        let condition = self.try_gt(&zero)?;
241        let neg_branch = alpha_t.try_mul(self)?;
242        self.where_(&condition, &neg_branch)
243    }
244
245    /// PReLU: `x if x > 0, slope * x otherwise`.
246    ///
247    /// Like LeakyReLU but with a learned per-channel slope.
248    pub fn prelu(&self, slope: &Tensor) -> Result<Self> {
249        let zero = self.zero()?;
250        let condition = self.try_gt(&zero)?;
251        let neg_branch = self.try_mul(slope)?;
252        self.where_(&condition, &neg_branch)
253    }
254
255    /// Thresholded ReLU: `x if x > alpha, 0 otherwise`.
256    ///
257    /// # Arguments
258    /// * `alpha` - Threshold (default 1.0 in ONNX)
259    pub fn thresholded_relu(&self, alpha: f64) -> Result<Self> {
260        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
261        let zero = self.zero()?;
262        let condition = self.try_gt(&alpha_t)?;
263        self.where_(&condition, &zero)
264    }
265
266    /// ELU: `x if x > 0, alpha * (exp(x) - 1) otherwise`.
267    ///
268    /// # Arguments
269    /// * `alpha` - Scale for negative part (default 1.0 in ONNX)
270    pub fn elu(&self, alpha: f64) -> Result<Self> {
271        let zero = self.zero()?;
272        let one = self.one()?;
273        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
274        let condition = self.try_gt(&zero)?;
275        let exp_minus_1 = self.try_exp()?.try_sub(&one)?;
276        let neg_branch = alpha_t.try_mul(&exp_minus_1)?;
277        self.where_(&condition, &neg_branch)
278    }
279
280    /// SELU: `gamma * (alpha * exp(x) - alpha) if x <= 0, gamma * x if x > 0`.
281    ///
282    /// Self-normalizing activation with fixed constants.
283    ///
284    /// # Arguments
285    /// * `alpha` - Default 1.6732632...
286    /// * `gamma` - Default 1.0507010...
287    pub fn selu(&self, alpha: f64, gamma: f64) -> Result<Self> {
288        let zero = self.zero()?;
289        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
290        let gamma_t = self.broadcast_scalar(ConstValue::Float(gamma))?;
291        let condition = self.try_ge(&zero)?;
292        // neg: alpha * exp(x) - alpha
293        let neg_branch =
294            alpha_t.try_mul(&self.try_exp()?)?.try_sub(&self.broadcast_scalar(ConstValue::Float(alpha))?)?;
295        let selected = self.where_(&condition, &neg_branch)?;
296        gamma_t.try_mul(&selected)
297    }
298
299    /// Swish/SiLU activation: `x * sigmoid(x)`.
300    ///
301    /// Also known as SiLU (Sigmoid Linear Unit).
302    /// Used in modern CNN architectures and some Transformers.
303    ///
304    /// # Examples
305    /// ```ignore
306    /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
307    /// let y = x.swish()?;
308    /// ```
309    pub fn swish(&self) -> Result<Self> {
310        // swish(x) = x * sigmoid(x)
311        let sig = self.sigmoid()?;
312        self.try_mul(&sig)
313    }
314
315    /// Alias for `swish` (matches PyTorch naming).
316    pub fn silu(&self) -> Result<Self> {
317        self.swish()
318    }
319
320    /// Gated Linear Unit: splits `self` along `dim` into two halves,
321    /// returns `first_half * sigmoid(second_half)`.
322    pub fn glu(&self, dim: isize) -> Result<Self> {
323        let shape = self.shape()?;
324        let ndim = shape.len();
325        let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };
326        let full_size = shape[axis].as_const().expect("GLU dim must be concrete");
327        assert!(full_size % 2 == 0, "GLU dimension must be even, got {full_size}");
328        let half = full_size / 2;
329        let halves = self.split(&[half, half], dim)?;
330        let gate = halves[1].sigmoid()?;
331        halves[0].try_mul(&gate)
332    }
333
334    /// Softplus: `log(1 + exp(beta*x)) / beta`, numerically stable via logaddexp.
335    pub fn softplus(&self, beta: f64) -> Result<Self> {
336        let beta_t = self.broadcast_scalar(ConstValue::Float(beta))?;
337        let scaled = self.try_mul(&beta_t)?;
338        let zero = self.zero()?;
339        let inv_beta = self.broadcast_scalar(ConstValue::Float(1.0 / beta))?;
340        // logaddexp(scaled, 0) = max(scaled, 0) + log(exp(scaled - max) + exp(0 - max))
341        let m = scaled.maximum(&zero)?;
342        let stable = scaled.try_sub(&m)?.try_exp()?.try_add(&zero.try_sub(&m)?.try_exp()?)?.try_log()?.try_add(&m)?;
343        stable.try_mul(&inv_beta)
344    }
345
346    /// Mish: `x * tanh(softplus(x))`.
347    pub fn mish(&self) -> Result<Self> {
348        self.try_mul(&self.softplus(1.0)?.tanh()?)
349    }
350
351    /// ReLU6: `relu(x) - relu(x-6)` = `clamp(x, 0, 6)`.
352    pub fn relu6(&self) -> Result<Self> {
353        let six = self.broadcast_scalar(ConstValue::Int(6))?;
354        let relu_x = self.relu()?;
355        let relu_x6 = self.try_sub(&six)?.relu()?;
356        relu_x.try_sub(&relu_x6)
357    }
358
359    /// HardSwish: `x * relu6(x+3) / 6`.
360    pub fn hardswish(&self) -> Result<Self> {
361        let three = self.broadcast_scalar(ConstValue::Int(3))?;
362        let six = self.broadcast_scalar(ConstValue::Int(6))?;
363        let r6 = self.try_add(&three)?.relu6()?;
364        self.try_mul(&r6)?.try_div(&six)
365    }
366
367    /// Softsign: `x / (1 + |x|)`.
368    pub fn softsign(&self) -> Result<Self> {
369        let one = self.one()?;
370        let denom = one.try_add(&self.try_abs()?)?;
371        self.try_div(&denom)
372    }
373
374    /// CELU: `max(0, x) + min(0, alpha*(exp(x/alpha)-1))`.
375    pub fn celu(&self, alpha: f64) -> Result<Self> {
376        let zero = self.zero()?;
377        let one = self.one()?;
378        let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
379        let pos = self.maximum(&zero)?;
380        let neg = alpha_t.try_mul(&self.try_div(&alpha_t)?.try_exp()?.try_sub(&one)?)?.minimum(&zero)?;
381        pos.try_add(&neg)
382    }
383
384    /// Batch Normalization.
385    ///
386    /// Applies: `y = scale * (x - mean) * invstd + bias`
387    /// where `invstd = 1 / sqrt(var + epsilon)`
388    ///
389    /// This is the inference mode batchnorm (no running stats update).
390    /// The caller provides pre-computed mean and inverse standard deviation.
391    ///
392    /// # Arguments
393    /// * `scale` - Gamma/weight parameter (optional, defaults to 1)
394    /// * `bias` - Beta parameter (optional, defaults to 0)
395    /// * `mean` - Running mean
396    /// * `invstd` - Inverse standard deviation (1 / sqrt(var + eps))
397    /// * `axis` - Axis/axes to normalize over (default: 1 for NCHW)
398    ///
399    /// # Examples
400    /// ```ignore
401    /// let x = Tensor::randn(&[8, 4, 16, 16]);
402    /// let mean = x.mean(AxisSpec::Multiple(vec![0, 2, 3]))?;
403    /// let var = x.var(AxisSpec::Multiple(vec![0, 2, 3]))?;
404    /// let eps = Tensor::from_slice([1e-5]);
405    /// let invstd = var.try_add(&eps)?.try_rsqrt()?;
406    /// let normalized = x.batchnorm().mean(&mean).invstd(&invstd).call()?;
407    /// ```
408    #[builder]
409    pub fn batchnorm(
410        &self,
411        scale: Option<&Tensor>,
412        bias: Option<&Tensor>,
413        mean: &Tensor,
414        invstd: &Tensor,
415        #[builder(default = AxisSpec::Single(1))] axis: AxisSpec,
416    ) -> Result<Self> {
417        let shape = self.shape()?;
418
419        // Build broadcast shape: keep axis dimensions, others become 1
420        let axis_indices: std::collections::HashSet<usize> = match &axis {
421            AxisSpec::All => (0..shape.len()).collect(),
422            AxisSpec::Single(a) => {
423                let a = if *a < 0 { (shape.len() as isize + *a) as usize } else { *a as usize };
424                std::iter::once(a).collect()
425            }
426            AxisSpec::Multiple(axes) => {
427                axes.iter().map(|&a| if a < 0 { (shape.len() as isize + a) as usize } else { a as usize }).collect()
428            }
429        };
430
431        let broadcast_shape: Vec<isize> = shape
432            .iter()
433            .enumerate()
434            .map(|(i, dim)| if axis_indices.contains(&i) { dim.as_const().unwrap_or(1) as isize } else { 1 })
435            .collect();
436
437        // x - mean (reshape mean to broadcast shape, like Tinygrad does)
438        let mean_reshaped = mean.try_reshape(&broadcast_shape)?;
439        let x_centered = self.try_sub(&mean_reshaped)?;
440
441        // (x - mean) * invstd
442        let invstd_reshaped = invstd.try_reshape(&broadcast_shape)?;
443        let mut result = x_centered.try_mul(&invstd_reshaped)?;
444
445        // scale * (x - mean) * invstd
446        if let Some(w) = scale {
447            let w_reshaped = w.try_reshape(&broadcast_shape)?;
448            result = result.try_mul(&w_reshaped)?;
449        }
450
451        // scale * (x - mean) * invstd + bias
452        if let Some(b) = bias {
453            let b_reshaped = b.try_reshape(&broadcast_shape)?;
454            result = result.try_add(&b_reshaped)?;
455        }
456
457        Ok(result)
458    }
459}