// trustformers_core/tensor/activations.rs
1//! Tensor activation functions.
2//!
3//! This module contains activation functions commonly used in neural networks.
4//!
5//! # Performance
6//!
7//! This module uses scirs2-core's SIMD-optimized activation functions for larger tensors:
8//! - `simd_gelu` - GELU activation (used in BERT, GPT, etc.)
9//! - `simd_swish` - Swish/SiLU activation (used in EfficientNet, GPT-NeoX)
10//! - `simd_sigmoid` - Sigmoid activation
11//! - `simd_tanh` - Tanh activation
12//!
13//! For tensors with <256 elements, uses scalar operations to avoid SIMD overhead.
14
15#![allow(deprecated)] // Using rand legacy API, will migrate to scirs2_core
16
17use super::Tensor;
18use crate::errors::{Result, TrustformersError};
19use scirs2_core::ndarray::{Axis, IxDyn};
20use scirs2_core::simd_ops::SimdUnifiedOps;
21
/// Minimum tensor size to use SIMD operations (avoids overhead for small tensors).
/// Below this element count, SIMD setup/reshape costs outweigh the vectorization gain,
/// so the activation functions fall back to scalar `mapv` paths.
const MIN_SIZE_FOR_SIMD: usize = 256;
24
25impl Tensor {
26    /// ReLU activation function.
27    ///
28    /// # Returns
29    ///
30    /// A tensor with ReLU applied element-wise.
31    pub fn relu(&self) -> Result<Tensor> {
32        match self {
33            Tensor::F32(a) => {
34                let result = a.mapv(|x| x.max(0.0));
35                Ok(Tensor::F32(result))
36            },
37            Tensor::F64(a) => {
38                let result = a.mapv(|x| x.max(0.0));
39                Ok(Tensor::F64(result))
40            },
41            _ => Err(TrustformersError::tensor_op_error(
42                "ReLU not supported for this tensor type",
43                "relu",
44            )),
45        }
46    }
47
48    /// Sigmoid activation function.
49    ///
50    /// # Performance
51    ///
52    /// Uses scirs2-core's SIMD-accelerated sigmoid for tensors with ≥256 elements.
53    ///
54    /// # Returns
55    ///
56    /// A tensor with sigmoid applied element-wise.
57    pub fn sigmoid(&self) -> Result<Tensor> {
58        match self {
59            Tensor::F32(a) => {
60                let size = a.len();
61                if size >= MIN_SIZE_FOR_SIMD {
62                    // Use SIMD-accelerated sigmoid for larger tensors
63                    let shape = a.shape().to_vec();
64                    let flat = a.as_standard_layout();
65                    let flat_view = flat
66                        .view()
67                        .into_shape_with_order(size)
68                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
69                    let result_1d = f32::simd_sigmoid(&flat_view);
70                    let result = result_1d
71                        .into_shape_with_order(IxDyn(&shape))
72                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
73                    Ok(Tensor::F32(result))
74                } else {
75                    // Numerically stable sigmoid implementation for small tensors
76                    let result = a.mapv(|x| {
77                        if x >= 0.0 {
78                            let exp_neg_x = (-x).exp();
79                            1.0 / (1.0 + exp_neg_x)
80                        } else {
81                            let exp_x = x.exp();
82                            exp_x / (1.0 + exp_x)
83                        }
84                    });
85                    Ok(Tensor::F32(result))
86                }
87            },
88            Tensor::F64(a) => {
89                let size = a.len();
90                if size >= MIN_SIZE_FOR_SIMD {
91                    // Use SIMD-accelerated sigmoid for larger tensors
92                    let shape = a.shape().to_vec();
93                    let flat = a.as_standard_layout();
94                    let flat_view = flat
95                        .view()
96                        .into_shape_with_order(size)
97                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
98                    let result_1d = f64::simd_sigmoid(&flat_view);
99                    let result = result_1d
100                        .into_shape_with_order(IxDyn(&shape))
101                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
102                    Ok(Tensor::F64(result))
103                } else {
104                    // Numerically stable sigmoid implementation for small tensors
105                    let result = a.mapv(|x| {
106                        if x >= 0.0 {
107                            let exp_neg_x = (-x).exp();
108                            1.0 / (1.0 + exp_neg_x)
109                        } else {
110                            let exp_x = x.exp();
111                            exp_x / (1.0 + exp_x)
112                        }
113                    });
114                    Ok(Tensor::F64(result))
115                }
116            },
117            _ => Err(TrustformersError::tensor_op_error(
118                "Sigmoid not supported for this tensor type",
119                "sigmoid",
120            )),
121        }
122    }
123
124    /// Tanh activation function.
125    ///
126    /// # Performance
127    ///
128    /// Uses scirs2-core's SIMD-accelerated tanh for tensors with ≥256 elements.
129    ///
130    /// # Returns
131    ///
132    /// A tensor with tanh applied element-wise.
133    pub fn tanh(&self) -> Result<Tensor> {
134        match self {
135            Tensor::F32(a) => {
136                let size = a.len();
137                if size >= MIN_SIZE_FOR_SIMD {
138                    // Use SIMD-accelerated tanh for larger tensors
139                    let shape = a.shape().to_vec();
140                    let flat = a.as_standard_layout();
141                    let flat_view = flat
142                        .view()
143                        .into_shape_with_order(size)
144                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
145                    let result_1d = f32::simd_tanh(&flat_view);
146                    let result = result_1d
147                        .into_shape_with_order(IxDyn(&shape))
148                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
149                    Ok(Tensor::F32(result))
150                } else {
151                    let result = a.mapv(|x| x.tanh());
152                    Ok(Tensor::F32(result))
153                }
154            },
155            Tensor::F64(a) => {
156                let size = a.len();
157                if size >= MIN_SIZE_FOR_SIMD {
158                    // Use SIMD-accelerated tanh for larger tensors
159                    let shape = a.shape().to_vec();
160                    let flat = a.as_standard_layout();
161                    let flat_view = flat
162                        .view()
163                        .into_shape_with_order(size)
164                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
165                    let result_1d = f64::simd_tanh(&flat_view);
166                    let result = result_1d
167                        .into_shape_with_order(IxDyn(&shape))
168                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
169                    Ok(Tensor::F64(result))
170                } else {
171                    let result = a.mapv(|x| x.tanh());
172                    Ok(Tensor::F64(result))
173                }
174            },
175            _ => Err(TrustformersError::tensor_op_error(
176                "Tanh not supported for this tensor type",
177                "tanh",
178            )),
179        }
180    }
181
182    /// Softmax activation function.
183    ///
184    /// # Arguments
185    ///
186    /// * `axis` - The axis along which to apply softmax
187    ///
188    /// # Returns
189    ///
190    /// A tensor with softmax applied along the specified axis.
191    pub fn softmax(&self, axis: i32) -> Result<Tensor> {
192        match self {
193            Tensor::F32(a) => {
194                let ndim = a.ndim();
195                let axis = if axis < 0 { (ndim as i32 + axis) as usize } else { axis as usize };
196
197                if axis >= ndim {
198                    return Err(TrustformersError::shape_error(format!(
199                        "Axis {} is out of bounds for tensor with {} dimensions",
200                        axis, ndim
201                    )));
202                }
203
204                // Ensure contiguous input layout
205                let a_contiguous = a.as_standard_layout().to_owned();
206
207                // For numerical stability, subtract max before exp
208                let max_vals = a_contiguous.map_axis(Axis(axis), |lane| {
209                    lane.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x))
210                });
211
212                // Ensure contiguous max_vals and compute shifted values
213                let max_vals_contiguous = max_vals.as_standard_layout().to_owned();
214                let shifted = &a_contiguous - &max_vals_contiguous.insert_axis(Axis(axis));
215                let shifted_contiguous = shifted.as_standard_layout().to_owned();
216
217                // Compute exp and sum with contiguous layout
218                let exp_vals = shifted_contiguous.mapv(|x| x.exp());
219                let exp_vals_contiguous = exp_vals.as_standard_layout().to_owned();
220                let sum_exp = exp_vals_contiguous.sum_axis(Axis(axis));
221                let sum_exp_contiguous = sum_exp.as_standard_layout().to_owned();
222
223                // Protect against division by very small numbers
224                let protected_sum = sum_exp_contiguous.mapv(|x| {
225                    if x <= f32::MIN_POSITIVE {
226                        f32::MIN_POSITIVE
227                    } else {
228                        x
229                    }
230                });
231
232                // Final result with contiguous layout
233                let result = exp_vals_contiguous / protected_sum.insert_axis(Axis(axis));
234                let result_contiguous = result.as_standard_layout().to_owned();
235                Ok(Tensor::F32(result_contiguous))
236            },
237            Tensor::F64(a) => {
238                let ndim = a.ndim();
239                let axis = if axis < 0 { (ndim as i32 + axis) as usize } else { axis as usize };
240
241                if axis >= ndim {
242                    return Err(TrustformersError::shape_error(format!(
243                        "Axis {} is out of bounds for tensor with {} dimensions",
244                        axis, ndim
245                    )));
246                }
247
248                // Ensure contiguous input layout
249                let a_contiguous = a.as_standard_layout().to_owned();
250
251                let max_vals = a_contiguous.map_axis(Axis(axis), |lane| {
252                    lane.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
253                });
254
255                // Ensure contiguous layouts throughout computation
256                let max_vals_contiguous = max_vals.as_standard_layout().to_owned();
257                let shifted = &a_contiguous - &max_vals_contiguous.insert_axis(Axis(axis));
258                let shifted_contiguous = shifted.as_standard_layout().to_owned();
259
260                let exp_vals = shifted_contiguous.mapv(|x| x.exp());
261                let exp_vals_contiguous = exp_vals.as_standard_layout().to_owned();
262                let sum_exp = exp_vals_contiguous.sum_axis(Axis(axis));
263                let sum_exp_contiguous = sum_exp.as_standard_layout().to_owned();
264
265                // Protect against division by very small numbers
266                let protected_sum = sum_exp_contiguous.mapv(|x| {
267                    if x <= f64::MIN_POSITIVE {
268                        f64::MIN_POSITIVE
269                    } else {
270                        x
271                    }
272                });
273
274                let result = exp_vals_contiguous / protected_sum.insert_axis(Axis(axis));
275                let result_contiguous = result.as_standard_layout().to_owned();
276                Ok(Tensor::F64(result_contiguous))
277            },
278            _ => Err(TrustformersError::tensor_op_error(
279                "Softmax not supported for this tensor type",
280                "softmax",
281            )),
282        }
283    }
284
285    /// Dropout operation.
286    ///
287    /// # Arguments
288    ///
289    /// * `dropout_prob` - Probability of dropping each element
290    ///
291    /// # Returns
292    ///
293    /// A tensor with dropout applied.
294    pub fn dropout(&self, dropout_prob: f32) -> Result<Tensor> {
295        use scirs2_core::random::*;
296
297        if !(0.0..=1.0).contains(&dropout_prob) {
298            return Err(TrustformersError::tensor_op_error(
299                "Dropout probability must be between 0 and 1",
300                "dropout",
301            ));
302        }
303
304        if dropout_prob == 0.0 {
305            return Ok(self.clone());
306        }
307
308        match self {
309            Tensor::F32(a) => {
310                let mut rng = thread_rng();
311                let scale = 1.0 / (1.0 - dropout_prob);
312                let result =
313                    a.mapv(
314                        |x| {
315                            if rng.random::<f32>() < dropout_prob {
316                                0.0
317                            } else {
318                                x * scale
319                            }
320                        },
321                    );
322                Ok(Tensor::F32(result))
323            },
324            _ => Err(TrustformersError::tensor_op_error(
325                "Dropout not supported for this tensor type",
326                "dropout",
327            )),
328        }
329    }
330
331    /// GELU (Gaussian Error Linear Unit) activation function.
332    ///
333    /// # Performance
334    ///
335    /// Uses scirs2-core's SIMD-accelerated GELU for tensors with ≥256 elements.
336    /// GELU is widely used in Transformer models (BERT, GPT, etc.).
337    ///
338    /// # Returns
339    ///
340    /// A tensor with GELU applied element-wise.
341    pub fn gelu(&self) -> Result<Tensor> {
342        match self {
343            // Metal GPU path - stays on GPU!
344            #[cfg(all(target_os = "macos", feature = "metal"))]
345            Tensor::Metal(metal_data) => {
346                use crate::gpu_ops::metal::get_metal_backend;
347                use crate::tensor::MetalTensorData;
348
349                let backend = get_metal_backend()?;
350                let size = metal_data.shape.iter().product();
351
352                let output_buffer_id = backend.gelu_gpu_to_gpu(&metal_data.buffer_id, size)?;
353
354                Ok(Tensor::Metal(MetalTensorData {
355                    buffer_id: output_buffer_id,
356                    shape: metal_data.shape.clone(),
357                    dtype: metal_data.dtype,
358                }))
359            },
360            Tensor::F32(a) => {
361                let size = a.len();
362                if size >= MIN_SIZE_FOR_SIMD {
363                    // Use SIMD-accelerated GELU for larger tensors
364                    let shape = a.shape().to_vec();
365                    let flat = a.as_standard_layout();
366                    let flat_view = flat
367                        .view()
368                        .into_shape_with_order(size)
369                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
370                    let result_1d = f32::simd_gelu(&flat_view);
371                    let result = result_1d
372                        .into_shape_with_order(IxDyn(&shape))
373                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
374                    Ok(Tensor::F32(result))
375                } else {
376                    // Scalar path for small tensors
377                    let result = a.mapv(|x| {
378                        0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x.powi(3))).tanh())
379                    });
380                    Ok(Tensor::F32(result))
381                }
382            },
383            Tensor::F64(a) => {
384                let size = a.len();
385                if size >= MIN_SIZE_FOR_SIMD {
386                    // Use SIMD-accelerated GELU for larger tensors
387                    let shape = a.shape().to_vec();
388                    let flat = a.as_standard_layout();
389                    let flat_view = flat
390                        .view()
391                        .into_shape_with_order(size)
392                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
393                    let result_1d = f64::simd_gelu(&flat_view);
394                    let result = result_1d
395                        .into_shape_with_order(IxDyn(&shape))
396                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
397                    Ok(Tensor::F64(result))
398                } else {
399                    // Scalar path for small tensors
400                    let result = a.mapv(|x| {
401                        0.5 * x * (1.0 + (0.7978845608028654 * (x + 0.044715 * x.powi(3))).tanh())
402                    });
403                    Ok(Tensor::F64(result))
404                }
405            },
406            _ => Err(TrustformersError::tensor_op_error(
407                "GELU not supported for this tensor type",
408                "gelu",
409            )),
410        }
411    }
412
413    /// Leaky ReLU activation function.
414    ///
415    /// # Arguments
416    ///
417    /// * `negative_slope` - The slope for negative values (default: 0.01)
418    ///
419    /// # Returns
420    ///
421    /// A tensor with Leaky ReLU applied element-wise.
422    pub fn leaky_relu(&self, negative_slope: f32) -> Result<Tensor> {
423        match self {
424            Tensor::F32(a) => {
425                let result = a.mapv(|x| if x > 0.0 { x } else { negative_slope * x });
426                Ok(Tensor::F32(result))
427            },
428            Tensor::F64(a) => {
429                let negative_slope = negative_slope as f64;
430                let result = a.mapv(|x| if x > 0.0 { x } else { negative_slope * x });
431                Ok(Tensor::F64(result))
432            },
433            _ => Err(TrustformersError::tensor_op_error(
434                "Leaky ReLU not supported for this tensor type",
435                "leaky_relu",
436            )),
437        }
438    }
439
440    /// SiLU (Sigmoid-Linear Unit) activation function.
441    ///
442    /// Also known as Swish activation: f(x) = x * sigmoid(x)
443    ///
444    /// # Performance
445    ///
446    /// Uses scirs2-core's SIMD-accelerated Swish for tensors with ≥256 elements.
447    /// SiLU/Swish is used in EfficientNet, GPT-NeoX, and many modern architectures.
448    ///
449    /// # Returns
450    ///
451    /// A tensor with SiLU applied element-wise.
452    pub fn silu(&self) -> Result<Tensor> {
453        match self {
454            Tensor::F32(a) => {
455                let size = a.len();
456                if size >= MIN_SIZE_FOR_SIMD {
457                    // Use SIMD-accelerated Swish for larger tensors
458                    let shape = a.shape().to_vec();
459                    let flat = a.as_standard_layout();
460                    let flat_view = flat
461                        .view()
462                        .into_shape_with_order(size)
463                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
464                    let result_1d = f32::simd_swish(&flat_view);
465                    let result = result_1d
466                        .into_shape_with_order(IxDyn(&shape))
467                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
468                    Ok(Tensor::F32(result))
469                } else {
470                    // Scalar path for small tensors
471                    let result = a.mapv(|x| x * (1.0 / (1.0 + (-x).exp())));
472                    Ok(Tensor::F32(result))
473                }
474            },
475            Tensor::F64(a) => {
476                let size = a.len();
477                if size >= MIN_SIZE_FOR_SIMD {
478                    // Use SIMD-accelerated Swish for larger tensors
479                    let shape = a.shape().to_vec();
480                    let flat = a.as_standard_layout();
481                    let flat_view = flat
482                        .view()
483                        .into_shape_with_order(size)
484                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
485                    let result_1d = f64::simd_swish(&flat_view);
486                    let result = result_1d
487                        .into_shape_with_order(IxDyn(&shape))
488                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
489                    Ok(Tensor::F64(result))
490                } else {
491                    // Scalar path for small tensors
492                    let result = a.mapv(|x| x * (1.0 / (1.0 + (-x).exp())));
493                    Ok(Tensor::F64(result))
494                }
495            },
496            _ => Err(TrustformersError::tensor_op_error(
497                "SiLU not supported for this tensor type",
498                "silu",
499            )),
500        }
501    }
502
503    /// Swish activation function (alias for SiLU).
504    ///
505    /// Swish(x) = x * sigmoid(x) = SiLU(x)
506    pub fn swish(&self) -> Result<Tensor> {
507        self.silu()
508    }
509}