1use scirs2_core::ndarray::{ArrayD, Zip};
8
/// Failure modes reported by the activation functions in this file.
#[derive(Debug, Clone)]
pub enum ActivationError {
    /// The input tensor has no elements.
    EmptyInput,
    /// A scalar argument was rejected.
    InvalidParameter {
        /// Name of the rejected argument.
        name: String,
        /// The value that was passed (converted to `f64`).
        value: f64,
        /// Human-readable explanation of why the value was rejected.
        reason: String,
    },
    /// A tensor's shape did not match what the operation required.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
}
30
31impl std::fmt::Display for ActivationError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::EmptyInput => write!(f, "activation: input tensor is empty"),
35 Self::InvalidParameter {
36 name,
37 value,
38 reason,
39 } => {
40 write!(
41 f,
42 "activation: invalid parameter '{name}' = {value}: {reason}"
43 )
44 }
45 Self::ShapeMismatch { expected, got } => {
46 write!(
47 f,
48 "activation: shape mismatch — expected {expected:?}, got {got:?}"
49 )
50 }
51 }
52 }
53}
54
55impl std::error::Error for ActivationError {}
56
/// Approximates the error function `erf(x)`.
///
/// Uses the Abramowitz & Stegun formula 7.1.26,
/// `erf(x) ≈ 1 − (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·exp(−x²)` with
/// `t = 1 / (1 + p·x)` for `x ≥ 0`, whose stated maximum absolute error is
/// `1.5e-7`. Negative arguments are handled through the odd symmetry
/// `erf(−x) = −erf(x)`.
///
/// The previous implementation paired the coefficients of formula 7.1.27
/// with the functional form of 7.1.25, which gave e.g. `erf(0) ≈ 0.41`.
#[inline]
fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    const A1: f64 = 0.254_829_592;
    const A2: f64 = -0.284_496_736;
    const A3: f64 = 1.421_413_741;
    const A4: f64 = -1.453_152_027;
    const A5: f64 = 1.061_405_429;

    let sign = x.signum();
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner evaluation of a1·t + a2·t² + … + a5·t⁵.
    let poly = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;
    let result = 1.0 - poly * (-magnitude * magnitude).exp();
    sign * result
}
76
/// Logistic sigmoid `1 / (1 + e^(−x))` of a single value.
#[inline]
fn sigmoid_scalar_impl(x: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-x);
    denominator.recip()
}
81
/// Softplus of a single value: `ln(1 + e^(beta·x)) / beta`.
///
/// When `beta·x` exceeds 30 the function returns `x` directly, which avoids
/// overflow in `exp` for large arguments.
///
/// The logarithm is evaluated with `ln_1p`, so strongly negative inputs keep
/// their tiny positive value instead of rounding `1 + e^(beta·x)` to exactly
/// `1` and returning `0`.
///
/// `beta` is not validated here; callers are expected to pass a positive value.
#[inline]
fn softplus_scalar(x: f64, beta: f64) -> f64 {
    const LINEAR_THRESHOLD: f64 = 30.0;

    let scaled = beta * x;
    if scaled > LINEAR_THRESHOLD {
        x
    } else {
        scaled.exp().ln_1p() / beta
    }
}
92
/// Rectified linear unit of a single value: the larger of `x` and `0`.
///
/// Because this is `f64::max`, a NaN input yields `0.0`.
#[inline]
pub fn relu_scalar(x: f64) -> f64 {
    const FLOOR: f64 = 0.0;
    f64::max(x, FLOOR)
}
102
103#[inline]
105pub fn gelu_scalar(x: f64) -> f64 {
106 x * 0.5 * (1.0 + erf_approx(x / std::f64::consts::SQRT_2))
107}
108
109#[inline]
111pub fn swish_scalar(x: f64) -> f64 {
112 x * sigmoid_scalar_impl(x)
113}
114
115#[inline]
117pub fn sigmoid_scalar(x: f64) -> f64 {
118 sigmoid_scalar_impl(x)
119}
120
121pub fn relu(input: &ArrayD<f64>) -> ArrayD<f64> {
127 input.mapv(relu_scalar)
128}
129
130pub fn relu6(input: &ArrayD<f64>) -> ArrayD<f64> {
132 input.mapv(|x| x.clamp(0.0, 6.0))
133}
134
135pub fn leaky_relu(input: &ArrayD<f64>, negative_slope: f64) -> ArrayD<f64> {
137 input.mapv(|x| if x >= 0.0 { x } else { negative_slope * x })
138}
139
140pub fn elu(input: &ArrayD<f64>, alpha: f64) -> Result<ArrayD<f64>, ActivationError> {
144 if alpha < 0.0 {
145 return Err(ActivationError::InvalidParameter {
146 name: "alpha".into(),
147 value: alpha,
148 reason: "alpha must be non-negative for ELU".into(),
149 });
150 }
151 Ok(input.mapv(|x| if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) }))
152}
153
154pub fn selu(input: &ArrayD<f64>) -> ArrayD<f64> {
158 const ALPHA: f64 = 1.673_263_242_354_377_2;
159 const SCALE: f64 = 1.050_700_987_355_480_5;
160 input.mapv(|x| SCALE * if x >= 0.0 { x } else { ALPHA * (x.exp() - 1.0) })
161}
162
163pub fn gelu(input: &ArrayD<f64>) -> ArrayD<f64> {
165 input.mapv(gelu_scalar)
166}
167
168pub fn gelu_approx(input: &ArrayD<f64>) -> ArrayD<f64> {
171 const C: f64 = 0.797_884_560_802_865_4; input.mapv(|x| {
173 let inner = C * (x + 0.044_715 * x * x * x);
174 0.5 * x * (1.0 + inner.tanh())
175 })
176}
177
178pub fn swish(input: &ArrayD<f64>) -> ArrayD<f64> {
180 input.mapv(swish_scalar)
181}
182
183pub fn silu(input: &ArrayD<f64>) -> ArrayD<f64> {
185 swish(input)
186}
187
188pub fn mish(input: &ArrayD<f64>) -> ArrayD<f64> {
190 input.mapv(|x| {
191 let sp = softplus_scalar(x, 1.0);
192 x * sp.tanh()
193 })
194}
195
196pub fn softplus(input: &ArrayD<f64>, beta: f64) -> Result<ArrayD<f64>, ActivationError> {
200 if beta <= 0.0 {
201 return Err(ActivationError::InvalidParameter {
202 name: "beta".into(),
203 value: beta,
204 reason: "beta must be positive for Softplus".into(),
205 });
206 }
207 Ok(input.mapv(|x| softplus_scalar(x, beta)))
208}
209
210pub fn softsign(input: &ArrayD<f64>) -> ArrayD<f64> {
212 input.mapv(|x| x / (1.0 + x.abs()))
213}
214
215pub fn hardswish(input: &ArrayD<f64>) -> ArrayD<f64> {
217 input.mapv(|x| x * (x + 3.0).clamp(0.0, 6.0) / 6.0)
218}
219
220pub fn hardsigmoid(input: &ArrayD<f64>) -> ArrayD<f64> {
222 input.mapv(|x| (x + 3.0).clamp(0.0, 6.0) / 6.0)
223}
224
225pub fn sigmoid(input: &ArrayD<f64>) -> ArrayD<f64> {
227 input.mapv(sigmoid_scalar_impl)
228}
229
230pub fn tanh_activation(input: &ArrayD<f64>) -> ArrayD<f64> {
232 input.mapv(|x| x.tanh())
233}
234
235pub fn prelu(input: &ArrayD<f64>, weights: &ArrayD<f64>) -> Result<ArrayD<f64>, ActivationError> {
241 if input.is_empty() {
242 return Err(ActivationError::EmptyInput);
243 }
244
245 let channels = if input.ndim() == 0 {
247 1
248 } else {
249 input.shape()[0]
250 };
251 let w_len = weights.len();
252
253 if w_len != channels && w_len != 1 {
254 return Err(ActivationError::ShapeMismatch {
255 expected: vec![channels],
256 got: weights.shape().to_vec(),
257 });
258 }
259
260 let weights_flat: Vec<f64> = weights.iter().copied().collect();
261 let get_w = |ch: usize| -> f64 {
262 if w_len == 1 {
263 weights_flat[0]
264 } else {
265 weights_flat[ch]
266 }
267 };
268
269 if input.ndim() <= 1 {
270 let out: Vec<f64> = input
272 .iter()
273 .enumerate()
274 .map(|(i, &x)| {
275 let ch = if w_len == 1 { 0 } else { i };
276 if x >= 0.0 {
277 x
278 } else {
279 get_w(ch) * x
280 }
281 })
282 .collect();
283 return Ok(ArrayD::from_shape_vec(input.raw_dim(), out)
284 .unwrap_or_else(|_| input.mapv(relu_scalar)));
285 }
286
287 let shape = input.shape().to_vec();
289 let mut result = input.clone();
290 let stride: usize = shape[1..].iter().product();
291
292 for (idx, val) in result.iter_mut().enumerate() {
293 let ch = (idx / stride) % channels;
294 if *val < 0.0 {
295 *val *= get_w(ch);
296 }
297 }
298 Ok(result)
299}
300
301pub fn softmax(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ActivationError> {
307 if input.is_empty() {
308 return Err(ActivationError::EmptyInput);
309 }
310 if axis >= input.ndim() {
311 return Err(ActivationError::InvalidParameter {
312 name: "axis".into(),
313 value: axis as f64,
314 reason: format!("axis {} out of range for ndim {}", axis, input.ndim()),
315 });
316 }
317
318 let max_vals = input.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
320 lane.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
321 });
322
323 let mut shifted = input.clone();
324 Zip::from(&mut shifted)
326 .and_broadcast(&max_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
327 .for_each(|s, &m| *s -= m);
328
329 let mut exped = shifted.mapv(f64::exp);
330
331 let sum_vals = exped.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
332 lane.iter().cloned().sum::<f64>()
333 });
334
335 Zip::from(&mut exped)
336 .and_broadcast(&sum_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
337 .for_each(|e, &s| *e /= s);
338
339 Ok(exped)
340}
341
342pub fn log_softmax(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ActivationError> {
344 if input.is_empty() {
345 return Err(ActivationError::EmptyInput);
346 }
347 if axis >= input.ndim() {
348 return Err(ActivationError::InvalidParameter {
349 name: "axis".into(),
350 value: axis as f64,
351 reason: format!("axis {} out of range for ndim {}", axis, input.ndim()),
352 });
353 }
354
355 let max_vals = input.map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
356 lane.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
357 });
358
359 let mut shifted = input.clone();
360 Zip::from(&mut shifted)
361 .and_broadcast(&max_vals.insert_axis(scirs2_core::ndarray::Axis(axis)))
362 .for_each(|s, &m| *s -= m);
363
364 let log_sum_exp = shifted
365 .mapv(f64::exp)
366 .map_axis(scirs2_core::ndarray::Axis(axis), |lane| {
367 lane.iter().cloned().sum::<f64>().ln()
368 });
369
370 Zip::from(&mut shifted)
371 .and_broadcast(&log_sum_exp.insert_axis(scirs2_core::ndarray::Axis(axis)))
372 .for_each(|s, &lse| *s -= lse);
373
374 Ok(shifted)
375}
376
377pub fn relu_grad(input: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
383 let mut out = grad_output.clone();
384 Zip::from(&mut out).and(input).for_each(|g, &x| {
385 if x <= 0.0 {
386 *g = 0.0;
387 }
388 });
389 out
390}
391
392pub fn sigmoid_grad(output: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
396 let mut out = grad_output.clone();
397 Zip::from(&mut out)
398 .and(output)
399 .for_each(|g, &s| *g *= s * (1.0 - s));
400 out
401}
402
403pub fn tanh_grad(output: &ArrayD<f64>, grad_output: &ArrayD<f64>) -> ArrayD<f64> {
407 let mut out = grad_output.clone();
408 Zip::from(&mut out)
409 .and(output)
410 .for_each(|g, &t| *g *= 1.0 - t * t);
411 out
412}
413
/// Selector for the element-wise activations in this file, carrying any
/// scalar parameter the activation needs.
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationType {
    /// Rectified linear unit.
    Relu,
    /// ReLU clamped to `[0, 6]`.
    Relu6,
    /// Leaky ReLU; the payload is the slope applied to negative inputs.
    LeakyRelu(f64),
    /// Exponential linear unit; the payload is `alpha`.
    Elu(f64),
    /// Scaled exponential linear unit.
    Selu,
    /// GELU evaluated through the error function.
    Gelu,
    /// GELU evaluated through the tanh approximation.
    GeluApprox,
    /// Swish / SiLU.
    Swish,
    /// Mish.
    Mish,
    /// Softplus; the payload is `beta`.
    Softplus(f64),
    /// Softsign.
    Softsign,
    /// Hard-swish.
    Hardswish,
    /// Hard-sigmoid.
    Hardsigmoid,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
}
437
438impl ActivationType {
439 pub fn apply(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, ActivationError> {
441 match self {
442 Self::Relu => Ok(relu(input)),
443 Self::Relu6 => Ok(relu6(input)),
444 Self::LeakyRelu(s) => Ok(leaky_relu(input, *s)),
445 Self::Elu(a) => elu(input, *a),
446 Self::Selu => Ok(selu(input)),
447 Self::Gelu => Ok(gelu(input)),
448 Self::GeluApprox => Ok(gelu_approx(input)),
449 Self::Swish => Ok(swish(input)),
450 Self::Mish => Ok(mish(input)),
451 Self::Softplus(b) => softplus(input, *b),
452 Self::Softsign => Ok(softsign(input)),
453 Self::Hardswish => Ok(hardswish(input)),
454 Self::Hardsigmoid => Ok(hardsigmoid(input)),
455 Self::Sigmoid => Ok(sigmoid(input)),
456 Self::Tanh => Ok(tanh_activation(input)),
457 }
458 }
459
460 pub fn name(&self) -> &'static str {
462 match self {
463 Self::Relu => "relu",
464 Self::Relu6 => "relu6",
465 Self::LeakyRelu(_) => "leaky_relu",
466 Self::Elu(_) => "elu",
467 Self::Selu => "selu",
468 Self::Gelu => "gelu",
469 Self::GeluApprox => "gelu_approx",
470 Self::Swish => "swish",
471 Self::Mish => "mish",
472 Self::Softplus(_) => "softplus",
473 Self::Softsign => "softsign",
474 Self::Hardswish => "hardswish",
475 Self::Hardsigmoid => "hardsigmoid",
476 Self::Sigmoid => "sigmoid",
477 Self::Tanh => "tanh",
478 }
479 }
480
481 pub fn is_monotone(&self) -> bool {
483 matches!(
484 self,
485 Self::Relu
486 | Self::Relu6
487 | Self::LeakyRelu(_)
488 | Self::Elu(_)
489 | Self::Selu
490 | Self::Gelu
491 | Self::GeluApprox
492 | Self::Swish
493 | Self::Softplus(_)
494 | Self::Softsign
495 | Self::Sigmoid
496 | Self::Tanh
497 )
498 }
499
500 pub fn output_range(&self) -> (f64, f64) {
502 match self {
503 Self::Relu => (0.0, f64::INFINITY),
504 Self::Relu6 => (0.0, 6.0),
505 Self::LeakyRelu(_) => (f64::NEG_INFINITY, f64::INFINITY),
506 Self::Elu(_) | Self::Selu => (f64::NEG_INFINITY, f64::INFINITY),
507 Self::Gelu | Self::GeluApprox => (f64::NEG_INFINITY, f64::INFINITY),
508 Self::Swish | Self::Mish => (f64::NEG_INFINITY, f64::INFINITY),
509 Self::Softplus(_) => (0.0, f64::INFINITY),
510 Self::Softsign => (-1.0, 1.0),
511 Self::Hardswish => (f64::NEG_INFINITY, f64::INFINITY),
512 Self::Hardsigmoid => (0.0, 1.0),
513 Self::Sigmoid => (0.0, 1.0),
514 Self::Tanh => (-1.0, 1.0),
515 }
516 }
517}
518
/// Summary statistics of one activation's output over one input tensor,
/// as produced by `ActivationBenchmark::compute`.
#[derive(Debug, Clone)]
pub struct ActivationBenchmark {
    /// Name of the activation that was applied.
    pub name: String,
    /// Number of elements in the input tensor.
    pub input_size: usize,
    /// Arithmetic mean of the output values.
    pub mean_output: f64,
    /// Standard deviation of the output values.
    pub std_output: f64,
    /// Smallest output value.
    pub min_output: f64,
    /// Largest output value.
    pub max_output: f64,
}
533
534impl ActivationBenchmark {
535 pub fn compute(
537 activation: &ActivationType,
538 input: &ArrayD<f64>,
539 ) -> Result<Self, ActivationError> {
540 if input.is_empty() {
541 return Err(ActivationError::EmptyInput);
542 }
543 let output = activation.apply(input)?;
544 let n = output.len() as f64;
545 let values: Vec<f64> = output.iter().copied().collect();
546
547 let mean = values.iter().sum::<f64>() / n;
548 let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
549 let std_output = variance.sqrt();
550 let min_output = values.iter().cloned().fold(f64::INFINITY, f64::min);
551 let max_output = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
552
553 Ok(Self {
554 name: activation.name().to_owned(),
555 input_size: input.len(),
556 mean_output: mean,
557 std_output,
558 min_output,
559 max_output,
560 })
561 }
562
563 pub fn summary(&self) -> String {
565 format!(
566 "{} [n={}] mean={:.4} std={:.4} min={:.4} max={:.4}",
567 self.name,
568 self.input_size,
569 self.mean_output,
570 self.std_output,
571 self.min_output,
572 self.max_output,
573 )
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array2};

    /// Default absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-6;

    /// Builds a one-dimensional dynamic-rank tensor from a slice.
    fn tensor(values: &[f64]) -> ArrayD<f64> {
        arr1(values).into_dyn()
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_near(actual: f64, expected: f64, tol: f64, label: &str) {
        let distance = (actual - expected).abs();
        assert!(
            distance < tol,
            "{}: |{} - {}| >= {}",
            label,
            actual,
            expected,
            tol
        );
    }

    #[test]
    fn test_relu_zeros_negative() {
        let out = relu(&tensor(&[-3.0, -1.0, 0.0]));
        assert!(
            out.iter().all(|&v| v == 0.0),
            "ReLU of non-positive must be 0"
        );
    }

    #[test]
    fn test_relu_positive_unchanged() {
        let input = tensor(&[1.0, 2.5, 100.0]);
        let out = relu(&input);
        assert!(
            input.iter().zip(out.iter()).all(|(&i, &o)| i == o),
            "ReLU must preserve positive values"
        );
    }

    #[test]
    fn test_relu6_clamp() {
        let out = relu6(&tensor(&[7.0, 6.0, 5.0, -1.0]));
        let expectations = [
            (0_usize, 6.0, "values > 6 must be clamped to 6"),
            (1, 6.0, "6 must map to 6"),
            (2, 5.0, "values <= 6 must be unchanged (if non-negative)"),
            (3, 0.0, "negative values must be 0"),
        ];
        for &(idx, want, why) in expectations.iter() {
            assert_eq!(out[idx], want, "{}", why);
        }
    }

    #[test]
    fn test_leaky_relu_negative_slope() {
        let out = leaky_relu(&tensor(&[-4.0, -1.0, 0.0, 2.0]), 0.1);
        let expectations = [
            (0_usize, -0.4, "leaky_relu(-4, 0.1)"),
            (1, -0.1, "leaky_relu(-1, 0.1)"),
            (2, 0.0, "leaky_relu(0, 0.1)"),
            (3, 2.0, "leaky_relu(2, 0.1)"),
        ];
        for &(idx, want, label) in expectations.iter() {
            assert_near(out[idx], want, EPS, label);
        }
    }

    #[test]
    fn test_elu_positive_unchanged() {
        let input = tensor(&[0.5, 1.0, 3.0]);
        let out = elu(&input, 1.0).expect("elu should succeed");
        input
            .iter()
            .zip(out.iter())
            .for_each(|(&i, &o)| assert_near(i, o, EPS, "ELU positive must be identity"));
    }

    #[test]
    fn test_elu_negative_approaches_minus_alpha() {
        let alpha = 1.0;
        let out = elu(&tensor(&[-50.0]), alpha).expect("elu should succeed");
        assert_near(
            out[0],
            -alpha,
            1e-10,
            "ELU large-negative approaches -alpha",
        );
    }

    #[test]
    fn test_selu_scale() {
        const SCALE: f64 = 1.050_700_987_355_480_5;
        let input = tensor(&[1.0, 2.0, 3.0]);
        let out = selu(&input);
        input
            .iter()
            .zip(out.iter())
            .for_each(|(&i, &o)| assert_near(o, SCALE * i, EPS, "SELU positive = scale * x"));
    }

    #[test]
    fn test_gelu_near_zero() {
        let out = gelu(&tensor(&[0.0]));
        assert_near(out[0], 0.0, EPS, "gelu(0) must be 0");
    }

    #[test]
    fn test_gelu_positive() {
        let large = 10.0_f64;
        assert_near(
            gelu_scalar(large),
            large,
            1e-4,
            "gelu(large positive) ≈ large positive",
        );
    }

    #[test]
    fn test_swish_zero() {
        let out = swish(&tensor(&[0.0]));
        assert_near(out[0], 0.0, EPS, "swish(0) must be 0");
    }

    #[test]
    fn test_sigmoid_midpoint() {
        let out = sigmoid(&tensor(&[0.0]));
        assert_near(out[0], 0.5, EPS, "sigmoid(0) must be 0.5");
    }

    #[test]
    fn test_softmax_sums_to_one() {
        const ROWS: usize = 2;
        const COLS: usize = 4;
        let flat = vec![1.0, 2.0, 3.0, 4.0, 0.5, 1.5, 2.5, 3.5];
        let data = Array2::from_shape_vec((ROWS, COLS), flat)
            .expect("shape ok")
            .into_dyn();
        let out = softmax(&data, 1).expect("softmax ok");
        for row in 0..ROWS {
            let mut total = 0.0;
            for col in 0..COLS {
                total += out[[row, col]];
            }
            assert_near(total, 1.0, EPS, "softmax row sum");
        }
    }

    #[test]
    fn test_log_softmax_matches() {
        let data = tensor(&[1.0, 2.0, 3.0, 4.0]);
        let probabilities = softmax(&data, 0).expect("softmax ok");
        let log_probabilities = log_softmax(&data, 0).expect("log_softmax ok");
        probabilities
            .iter()
            .zip(log_probabilities.iter())
            .for_each(|(&p, &lp)| assert_near(p.ln(), lp, 1e-9, "log(softmax) == log_softmax"));
    }

    #[test]
    fn test_relu_grad_mask() {
        let forward_input = tensor(&[-2.0, 0.0, 3.0]);
        let upstream = tensor(&[1.0, 1.0, 1.0]);
        let out = relu_grad(&forward_input, &upstream);
        let expectations = [
            (0_usize, 0.0, "grad must be 0 for negative input"),
            (1, 0.0, "grad must be 0 for zero input"),
            (2, 1.0, "grad must pass through for positive input"),
        ];
        for &(idx, want, why) in expectations.iter() {
            assert_eq!(out[idx], want, "{}", why);
        }
    }

    #[test]
    fn test_sigmoid_grad_formula() {
        // s = 0.5 gives s * (1 - s) = 0.25; times an upstream gradient of 2.0 is 0.5.
        let forward_output = tensor(&[0.5]);
        let upstream = tensor(&[2.0]);
        let out = sigmoid_grad(&forward_output, &upstream);
        assert_near(out[0], 0.5, EPS, "sigmoid_grad(0.5) * 2.0 == 0.5");
    }

    #[test]
    fn test_activation_type_apply_relu() {
        let input = tensor(&[-1.0, 0.0, 1.0, 2.0]);
        let direct = relu(&input);
        let dispatched = ActivationType::Relu.apply(&input).expect("apply ok");
        direct
            .iter()
            .zip(dispatched.iter())
            .for_each(|(&e, &g)| assert_near(e, g, EPS, "ActivationType::Relu.apply == relu"));
    }

    #[test]
    fn test_activation_type_name() {
        let all_variants = [
            ActivationType::Relu,
            ActivationType::Relu6,
            ActivationType::LeakyRelu(0.1),
            ActivationType::Elu(1.0),
            ActivationType::Selu,
            ActivationType::Gelu,
            ActivationType::GeluApprox,
            ActivationType::Swish,
            ActivationType::Mish,
            ActivationType::Softplus(1.0),
            ActivationType::Softsign,
            ActivationType::Hardswish,
            ActivationType::Hardsigmoid,
            ActivationType::Sigmoid,
            ActivationType::Tanh,
        ];
        all_variants.iter().for_each(|variant| {
            assert!(
                !variant.name().is_empty(),
                "name must not be empty: {:?}",
                variant
            );
        });
    }

    #[test]
    fn test_activation_type_output_range() {
        let bounded_below = [
            ActivationType::Relu,
            ActivationType::Relu6,
            ActivationType::Softsign,
            ActivationType::Hardsigmoid,
            ActivationType::Sigmoid,
            ActivationType::Tanh,
            ActivationType::Softplus(1.0),
        ];
        bounded_below.iter().for_each(|variant| {
            let (lo, hi) = variant.output_range();
            assert!(lo <= hi, "output_range lo <= hi for {:?}", variant);
        });

        let exact = [
            (ActivationType::Relu6, 0.0, 6.0),
            (ActivationType::Sigmoid, 0.0, 1.0),
        ];
        for (variant, want_lo, want_hi) in exact.iter() {
            let (lo, hi) = variant.output_range();
            assert_eq!(lo, *want_lo);
            assert_eq!(hi, *want_hi);
        }
    }

    #[test]
    fn test_activation_benchmark_compute() {
        let input = tensor(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let bench =
            ActivationBenchmark::compute(&ActivationType::Relu, &input).expect("benchmark ok");
        assert_eq!(bench.name, "relu");
        assert_eq!(bench.input_size, 5);
        assert!(bench.min_output >= 0.0, "ReLU output must be non-negative");
        assert!(bench.max_output >= bench.min_output);
        assert!(!bench.summary().is_empty());
    }

    #[test]
    fn test_hardswish_bounds() {
        let out = hardswish(&tensor(&[-10.0, -3.0, 0.0, 3.0, 10.0]));
        let expectations = [
            (0_usize, 0.0, "hardswish(-10) = 0"),
            (1, 0.0, "hardswish(-3) = 0"),
            (2, 0.0, "hardswish(0) = 0"),
            (3, 3.0, "hardswish(3) = 3"),
            (4, 10.0, "hardswish(10) = 10"),
        ];
        for &(idx, want, label) in expectations.iter() {
            assert_near(out[idx], want, EPS, label);
        }
    }
}