svod_tensor/activation.rs
1//! Activation functions for neural networks.
2//!
3//! This module provides common activation functions used in deep learning,
4//! including relu, sigmoid, tanh, softmax, and their variants.
5
6use bon::bon;
7use snafu::ResultExt;
8use svod_ir::{ConstValue, UOp};
9
10use crate::reduce::AxisSpec;
11use crate::{Result, Tensor, error::UOpSnafu};
12
13#[bon]
14impl Tensor {
15 /// Rectified Linear Unit: `max(0, x)`.
16 ///
17 /// ReLU is one of the most common activation functions in deep learning.
18 /// It's simple, efficient, and helps mitigate the vanishing gradient problem.
19 ///
20 /// # Examples
21 /// ```ignore
22 /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
23 /// let y = x.relu()?;
24 /// // y = [0.0, 0.0, 0.0, 1.0, 2.0]
25 /// ```
26 pub fn relu(&self) -> Result<Self> {
27 let zero = self.zero()?;
28 let condition = self.try_gt(&zero)?;
29 self.where_(&condition, &zero)
30 }
31
32 /// Sigmoid activation: `1 / (1 + exp(-x))`.
33 ///
34 /// Maps input to range (0, 1), commonly used for binary classification.
35 ///
36 /// # Examples
37 /// ```ignore
38 /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
39 /// let y = x.sigmoid()?;
40 /// // y ≈ [0.119, 0.268, 0.5, 0.731, 0.880]
41 /// ```
42 pub fn sigmoid(&self) -> Result<Self> {
43 // sigmoid(x) = 1 / (1 + exp(-x)) = 1 / (1 + 2^(-x/ln2))
44 // Using exp2 matches Tinygrad's implementation for better hardware mapping.
45 let scale = self.broadcast_scalar(ConstValue::Float(-1.0 / std::f64::consts::LN_2))?;
46 let scaled = self.try_mul(&scale)?;
47 let exp2_val = scaled.try_exp2()?;
48 let one = exp2_val.one()?;
49 let denominator = one.try_add(&exp2_val)?;
50 let recip = Self::new(UOp::try_reciprocal(&denominator.uop()).context(UOpSnafu)?);
51 Ok(recip)
52 }
53
54 /// Hyperbolic tangent: `tanh(x)`.
55 ///
56 /// Maps input to range (-1, 1), centered at zero.
57 ///
58 /// # Examples
59 /// ```ignore
60 /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
61 /// let y = x.tanh()?;
62 /// // y ≈ [-0.964, -0.762, 0.0, 0.762, 0.964]
63 /// ```
64 pub fn tanh(&self) -> Result<Self> {
65 // Check if tanh is a UOp primitive
66 // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
67 // Or: tanh(x) = 2*sigmoid(2x) - 1
68 let two = self.broadcast_scalar(ConstValue::Int(2))?;
69 let two_x = self.try_mul(&two)?;
70 let sig = two_x.sigmoid()?;
71 let two_sig = two.try_mul(&sig)?;
72 let one = sig.one()?;
73 two_sig.try_sub(&one)
74 }
75
76 /// Softmax activation: `exp(x - max(x)) / sum(exp(x - max(x)))`.
77 ///
78 /// Converts logits to probability distribution over specified axis.
79 /// Numerically stable implementation using max subtraction.
80 ///
81 /// # Arguments
82 /// * `axis` - Axis along which to compute softmax (default: -1, last axis)
83 ///
84 /// # Examples
85 /// ```ignore
86 /// let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
87 /// let probs = logits.softmax(-1)?;
88 /// // sum(probs) = 1.0, probs[i] > 0 for all i
89 /// ```
90 pub fn softmax(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
91 let axis = axis.into();
92
93 // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
94 // keepdim preserves the reduced axis as size 1 for correct broadcasting
95 let max_val = self.max_with().axes(axis.clone()).keepdim(true).call()?;
96 let shifted = self.try_sub(&max_val)?;
97 let exp_shifted = shifted.try_exp()?;
98 let sum_exp = exp_shifted.sum_with().axes(axis).keepdim(true).call()?;
99
100 exp_shifted.try_div(&sum_exp)
101 }
102
103 /// Log-softmax activation: `log(softmax(x))`.
104 ///
105 /// Numerically stable implementation: `x - max(x) - log(sum(exp(x - max(x))))`.
106 ///
107 /// More numerically stable than computing `log(softmax(x))` separately.
108 ///
109 /// # Arguments
110 /// * `axis` - Axis along which to compute log-softmax (default: -1, last axis)
111 ///
112 /// # Examples
113 /// ```ignore
114 /// let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
115 /// let log_probs = logits.log_softmax(-1)?;
116 /// // More numerically stable than logits.softmax(-1)?.try_log()
117 /// ```
118 pub fn log_softmax(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
119 let axis = axis.into();
120
121 // log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))
122 // keepdim preserves the reduced axis as size 1 for correct broadcasting
123 let max_val = self.max_with().axes(axis.clone()).keepdim(true).call()?;
124 let shifted = self.try_sub(&max_val)?;
125 let exp_shifted = shifted.try_exp()?;
126 let sum_exp = exp_shifted.sum_with().axes(axis).keepdim(true).call()?;
127 let log_sum_exp = sum_exp.try_log()?;
128 shifted.try_sub(&log_sum_exp)
129 }
130
131 /// Log-sum-exp: `log(sum(exp(x)))`.
132 ///
133 /// Numerically stable implementation: `max(x) + log(sum(exp(x - max(x))))`.
134 ///
135 /// # Arguments
136 /// * `axis` - Axis along which to compute logsumexp
137 ///
138 /// # Examples
139 /// ```ignore
140 /// let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
141 /// let lse = x.logsumexp(-1)?;
142 /// ```
143 pub fn logsumexp(&self, axis: impl Into<AxisSpec>) -> Result<Self> {
144 let axis = axis.into();
145
146 // logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
147 // Use keepdim internally for correct broadcasting, then drop via max
148 let max_keepdim = self.max_with().axes(axis.clone()).keepdim(true).call()?;
149 let shifted = self.try_sub(&max_keepdim)?;
150 let exp_shifted = shifted.try_exp()?;
151 let sum_exp = exp_shifted.sum_with().axes(axis.clone()).keepdim(true).call()?;
152 let log_sum = sum_exp.try_log()?;
153 let result_keepdim = max_keepdim.try_add(&log_sum)?;
154
155 // Drop the keepdim axis — max over size-1 dim is effectively a squeeze
156 result_keepdim.max(axis)
157 }
158
159 /// GELU activation (Gaussian Error Linear Unit).
160 ///
161 /// Smooth approximation: `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`.
162 ///
163 /// GELU is the standard activation for Transformer models (BERT, GPT, etc.).
164 ///
165 /// # Examples
166 /// ```ignore
167 /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
168 /// let y = x.gelu()?;
169 /// ```
170 pub fn gelu(&self) -> Result<Self> {
171 // gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
172 // sqrt(2/π) ≈ 0.7978845608
173
174 let half = self.broadcast_scalar(ConstValue::Float(0.5))?;
175 let one = self.broadcast_scalar(ConstValue::Float(1.0))?;
176 let coef1 = self.broadcast_scalar(ConstValue::Float(0.7978845608))?;
177 let coef2 = self.broadcast_scalar(ConstValue::Float(0.044715))?;
178
179 // x^3
180 let x_squared = self.try_mul(self)?;
181 let x_cubed = x_squared.try_mul(self)?;
182
183 // 0.044715 * x^3
184 let cubic_term = coef2.try_mul(&x_cubed)?;
185
186 // x + 0.044715 * x^3
187 let inner = self.try_add(&cubic_term)?;
188
189 // sqrt(2/π) * (x + 0.044715 * x^3)
190 let scaled = coef1.try_mul(&inner)?;
191
192 // tanh(...)
193 let tanh_part = scaled.tanh()?;
194
195 // 1 + tanh(...)
196 let one_plus_tanh = one.try_add(&tanh_part)?;
197
198 // x * (1 + tanh(...))
199 let x_times = self.try_mul(&one_plus_tanh)?;
200
201 // 0.5 * x * (1 + tanh(...))
202 half.try_mul(&x_times)
203 }
204
205 /// Exact GELU: `0.5 * x * (1 + erf(x / sqrt(2)))`.
206 pub fn gelu_exact(&self) -> Result<Self> {
207 let dtype = self.uop().dtype();
208 let half = Tensor::const_(0.5f64, dtype.clone());
209 let one = Tensor::const_(1.0f64, dtype.clone());
210 let sqrt2 = Tensor::const_(std::f64::consts::SQRT_2, dtype);
211 half.try_mul(self)?.try_mul(&one.try_add(&self.try_div(&sqrt2)?.erf()?)?)
212 }
213
214 /// Hard Sigmoid: `clamp(alpha * x + beta, 0, 1)`.
215 ///
216 /// Piecewise linear approximation of sigmoid. Faster to compute.
217 ///
218 /// # Arguments
219 /// * `alpha` - Slope (default 0.2 in ONNX)
220 /// * `beta` - Offset (default 0.5 in ONNX)
221 pub fn hard_sigmoid(&self, alpha: f64, beta: f64) -> Result<Self> {
222 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
223 let beta_t = self.broadcast_scalar(ConstValue::Float(beta))?;
224 let zero = self.broadcast_scalar(ConstValue::Float(0.0))?;
225 let one = self.broadcast_scalar(ConstValue::Float(1.0))?;
226 let ax = alpha_t.try_mul(self)?;
227 let axb = ax.try_add(&beta_t)?;
228 // clamp(axb, 0, 1) = max(0, min(1, axb))
229 let clamped_low = axb.maximum(&zero)?;
230 clamped_low.minimum(&one)
231 }
232
233 /// Leaky ReLU: `x if x > 0, alpha * x otherwise`.
234 ///
235 /// # Arguments
236 /// * `alpha` - Negative slope (default 0.01 in ONNX)
237 pub fn leaky_relu(&self, alpha: f64) -> Result<Self> {
238 let zero = self.zero()?;
239 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
240 let condition = self.try_gt(&zero)?;
241 let neg_branch = alpha_t.try_mul(self)?;
242 self.where_(&condition, &neg_branch)
243 }
244
245 /// PReLU: `x if x > 0, slope * x otherwise`.
246 ///
247 /// Like LeakyReLU but with a learned per-channel slope.
248 pub fn prelu(&self, slope: &Tensor) -> Result<Self> {
249 let zero = self.zero()?;
250 let condition = self.try_gt(&zero)?;
251 let neg_branch = self.try_mul(slope)?;
252 self.where_(&condition, &neg_branch)
253 }
254
255 /// Thresholded ReLU: `x if x > alpha, 0 otherwise`.
256 ///
257 /// # Arguments
258 /// * `alpha` - Threshold (default 1.0 in ONNX)
259 pub fn thresholded_relu(&self, alpha: f64) -> Result<Self> {
260 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
261 let zero = self.zero()?;
262 let condition = self.try_gt(&alpha_t)?;
263 self.where_(&condition, &zero)
264 }
265
266 /// ELU: `x if x > 0, alpha * (exp(x) - 1) otherwise`.
267 ///
268 /// # Arguments
269 /// * `alpha` - Scale for negative part (default 1.0 in ONNX)
270 pub fn elu(&self, alpha: f64) -> Result<Self> {
271 let zero = self.zero()?;
272 let one = self.one()?;
273 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
274 let condition = self.try_gt(&zero)?;
275 let exp_minus_1 = self.try_exp()?.try_sub(&one)?;
276 let neg_branch = alpha_t.try_mul(&exp_minus_1)?;
277 self.where_(&condition, &neg_branch)
278 }
279
280 /// SELU: `gamma * (alpha * exp(x) - alpha) if x <= 0, gamma * x if x > 0`.
281 ///
282 /// Self-normalizing activation with fixed constants.
283 ///
284 /// # Arguments
285 /// * `alpha` - Default 1.6732632...
286 /// * `gamma` - Default 1.0507010...
287 pub fn selu(&self, alpha: f64, gamma: f64) -> Result<Self> {
288 let zero = self.zero()?;
289 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
290 let gamma_t = self.broadcast_scalar(ConstValue::Float(gamma))?;
291 let condition = self.try_ge(&zero)?;
292 // neg: alpha * exp(x) - alpha
293 let neg_branch =
294 alpha_t.try_mul(&self.try_exp()?)?.try_sub(&self.broadcast_scalar(ConstValue::Float(alpha))?)?;
295 let selected = self.where_(&condition, &neg_branch)?;
296 gamma_t.try_mul(&selected)
297 }
298
299 /// Swish/SiLU activation: `x * sigmoid(x)`.
300 ///
301 /// Also known as SiLU (Sigmoid Linear Unit).
302 /// Used in modern CNN architectures and some Transformers.
303 ///
304 /// # Examples
305 /// ```ignore
306 /// let x = Tensor::from_slice(&[-2.0f32, -1.0, 0.0, 1.0, 2.0]);
307 /// let y = x.swish()?;
308 /// ```
309 pub fn swish(&self) -> Result<Self> {
310 // swish(x) = x * sigmoid(x)
311 let sig = self.sigmoid()?;
312 self.try_mul(&sig)
313 }
314
315 /// Alias for `swish` (matches PyTorch naming).
316 pub fn silu(&self) -> Result<Self> {
317 self.swish()
318 }
319
320 /// Gated Linear Unit: splits `self` along `dim` into two halves,
321 /// returns `first_half * sigmoid(second_half)`.
322 pub fn glu(&self, dim: isize) -> Result<Self> {
323 let shape = self.shape()?;
324 let ndim = shape.len();
325 let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };
326 let full_size = shape[axis].as_const().expect("GLU dim must be concrete");
327 assert!(full_size % 2 == 0, "GLU dimension must be even, got {full_size}");
328 let half = full_size / 2;
329 let halves = self.split(&[half, half], dim)?;
330 let gate = halves[1].sigmoid()?;
331 halves[0].try_mul(&gate)
332 }
333
334 /// Softplus: `log(1 + exp(beta*x)) / beta`, numerically stable via logaddexp.
335 pub fn softplus(&self, beta: f64) -> Result<Self> {
336 let beta_t = self.broadcast_scalar(ConstValue::Float(beta))?;
337 let scaled = self.try_mul(&beta_t)?;
338 let zero = self.zero()?;
339 let inv_beta = self.broadcast_scalar(ConstValue::Float(1.0 / beta))?;
340 // logaddexp(scaled, 0) = max(scaled, 0) + log(exp(scaled - max) + exp(0 - max))
341 let m = scaled.maximum(&zero)?;
342 let stable = scaled.try_sub(&m)?.try_exp()?.try_add(&zero.try_sub(&m)?.try_exp()?)?.try_log()?.try_add(&m)?;
343 stable.try_mul(&inv_beta)
344 }
345
346 /// Mish: `x * tanh(softplus(x))`.
347 pub fn mish(&self) -> Result<Self> {
348 self.try_mul(&self.softplus(1.0)?.tanh()?)
349 }
350
351 /// ReLU6: `relu(x) - relu(x-6)` = `clamp(x, 0, 6)`.
352 pub fn relu6(&self) -> Result<Self> {
353 let six = self.broadcast_scalar(ConstValue::Int(6))?;
354 let relu_x = self.relu()?;
355 let relu_x6 = self.try_sub(&six)?.relu()?;
356 relu_x.try_sub(&relu_x6)
357 }
358
359 /// HardSwish: `x * relu6(x+3) / 6`.
360 pub fn hardswish(&self) -> Result<Self> {
361 let three = self.broadcast_scalar(ConstValue::Int(3))?;
362 let six = self.broadcast_scalar(ConstValue::Int(6))?;
363 let r6 = self.try_add(&three)?.relu6()?;
364 self.try_mul(&r6)?.try_div(&six)
365 }
366
367 /// Softsign: `x / (1 + |x|)`.
368 pub fn softsign(&self) -> Result<Self> {
369 let one = self.one()?;
370 let denom = one.try_add(&self.try_abs()?)?;
371 self.try_div(&denom)
372 }
373
374 /// CELU: `max(0, x) + min(0, alpha*(exp(x/alpha)-1))`.
375 pub fn celu(&self, alpha: f64) -> Result<Self> {
376 let zero = self.zero()?;
377 let one = self.one()?;
378 let alpha_t = self.broadcast_scalar(ConstValue::Float(alpha))?;
379 let pos = self.maximum(&zero)?;
380 let neg = alpha_t.try_mul(&self.try_div(&alpha_t)?.try_exp()?.try_sub(&one)?)?.minimum(&zero)?;
381 pos.try_add(&neg)
382 }
383
384 /// Batch Normalization.
385 ///
386 /// Applies: `y = scale * (x - mean) * invstd + bias`
387 /// where `invstd = 1 / sqrt(var + epsilon)`
388 ///
389 /// This is the inference mode batchnorm (no running stats update).
390 /// The caller provides pre-computed mean and inverse standard deviation.
391 ///
392 /// # Arguments
393 /// * `scale` - Gamma/weight parameter (optional, defaults to 1)
394 /// * `bias` - Beta parameter (optional, defaults to 0)
395 /// * `mean` - Running mean
396 /// * `invstd` - Inverse standard deviation (1 / sqrt(var + eps))
397 /// * `axis` - Axis/axes to normalize over (default: 1 for NCHW)
398 ///
399 /// # Examples
400 /// ```ignore
401 /// let x = Tensor::randn(&[8, 4, 16, 16]);
402 /// let mean = x.mean(AxisSpec::Multiple(vec![0, 2, 3]))?;
403 /// let var = x.var(AxisSpec::Multiple(vec![0, 2, 3]))?;
404 /// let eps = Tensor::from_slice([1e-5]);
405 /// let invstd = var.try_add(&eps)?.try_rsqrt()?;
406 /// let normalized = x.batchnorm().mean(&mean).invstd(&invstd).call()?;
407 /// ```
408 #[builder]
409 pub fn batchnorm(
410 &self,
411 scale: Option<&Tensor>,
412 bias: Option<&Tensor>,
413 mean: &Tensor,
414 invstd: &Tensor,
415 #[builder(default = AxisSpec::Single(1))] axis: AxisSpec,
416 ) -> Result<Self> {
417 let shape = self.shape()?;
418
419 // Build broadcast shape: keep axis dimensions, others become 1
420 let axis_indices: std::collections::HashSet<usize> = match &axis {
421 AxisSpec::All => (0..shape.len()).collect(),
422 AxisSpec::Single(a) => {
423 let a = if *a < 0 { (shape.len() as isize + *a) as usize } else { *a as usize };
424 std::iter::once(a).collect()
425 }
426 AxisSpec::Multiple(axes) => {
427 axes.iter().map(|&a| if a < 0 { (shape.len() as isize + a) as usize } else { a as usize }).collect()
428 }
429 };
430
431 let broadcast_shape: Vec<isize> = shape
432 .iter()
433 .enumerate()
434 .map(|(i, dim)| if axis_indices.contains(&i) { dim.as_const().unwrap_or(1) as isize } else { 1 })
435 .collect();
436
437 // x - mean (reshape mean to broadcast shape, like Tinygrad does)
438 let mean_reshaped = mean.try_reshape(&broadcast_shape)?;
439 let x_centered = self.try_sub(&mean_reshaped)?;
440
441 // (x - mean) * invstd
442 let invstd_reshaped = invstd.try_reshape(&broadcast_shape)?;
443 let mut result = x_centered.try_mul(&invstd_reshaped)?;
444
445 // scale * (x - mean) * invstd
446 if let Some(w) = scale {
447 let w_reshaped = w.try_reshape(&broadcast_shape)?;
448 result = result.try_mul(&w_reshaped)?;
449 }
450
451 // scale * (x - mean) * invstd + bias
452 if let Some(b) = bias {
453 let b_reshaped = b.try_reshape(&broadcast_shape)?;
454 result = result.try_add(&b_reshaped)?;
455 }
456
457 Ok(result)
458 }
459}