Skip to main content

sapient_backends_cpu/kernels/
elementwise.rs

//! Element-wise CPU kernels — arithmetic, activations, and mathematical ops.
//!
//! All kernels operate on F32 tensors. Binary ops accept operands with equal
//! element counts, or one single-element (scalar) operand that is applied to
//! every element of the other; general broadcasting is handled by the dispatch
//! layer after shape inference.
5
6use sapient_core::error::{Result, SapientError};
7use sapient_core::{DType, Tensor};
8
9// ── Helper ────────────────────────────────────────────────────────────────────
10
11/// Apply a unary f32 function element-wise.
12fn unary_f32<F: Fn(f32) -> f32>(x: &Tensor, f: F) -> Result<Tensor> {
13    if x.dtype() != DType::F32 {
14        return Err(SapientError::TypeMismatch {
15            expected: "f32".into(),
16            got: x.dtype().to_string(),
17        });
18    }
19    let data: Vec<f32> = x.to_f32_cow().iter().map(|&v| f(v)).collect();
20    Tensor::from_f32(&data, x.shape().clone())
21}
22
23/// Apply a binary f32 function element-wise (same shape only).
24fn binary_f32<F: Fn(f32, f32) -> f32>(a: &Tensor, b: &Tensor, f: F) -> Result<Tensor> {
25    // Handle scalar broadcast (numel == 1).
26    let a_cow = a.to_f32_cow();
27    let a_data = a_cow.as_ref();
28    let b_cow = b.to_f32_cow();
29    let b_data = b_cow.as_ref();
30
31    let (out, shape) = if a_data.len() == b_data.len() {
32        let out: Vec<f32> = a_data
33            .iter()
34            .zip(b_data.iter())
35            .map(|(&x, &y)| f(x, y))
36            .collect();
37        (out, a.shape().clone())
38    } else if b_data.len() == 1 {
39        let scalar = b_data[0];
40        let out: Vec<f32> = a_data.iter().map(|&x| f(x, scalar)).collect();
41        (out, a.shape().clone())
42    } else if a_data.len() == 1 {
43        let scalar = a_data[0];
44        let out: Vec<f32> = b_data.iter().map(|&y| f(scalar, y)).collect();
45        (out, b.shape().clone())
46    } else {
47        return Err(SapientError::ShapeMismatch {
48            expected: a.shape().dims().to_vec(),
49            got: b.shape().dims().to_vec(),
50        });
51    };
52
53    Tensor::from_f32(&out, shape)
54}
55
56// ── Arithmetic ────────────────────────────────────────────────────────────────
57
58pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
59    binary_f32(a, b, |x, y| x + y)
60}
61pub fn sub(a: &Tensor, b: &Tensor) -> Result<Tensor> {
62    binary_f32(a, b, |x, y| x - y)
63}
64pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
65    binary_f32(a, b, |x, y| x * y)
66}
67pub fn div(a: &Tensor, b: &Tensor) -> Result<Tensor> {
68    binary_f32(a, b, |x, y| x / y)
69}
70pub fn pow(a: &Tensor, b: &Tensor) -> Result<Tensor> {
71    binary_f32(a, b, |x, y| x.powf(y))
72}
73
74pub fn neg(x: &Tensor) -> Result<Tensor> {
75    unary_f32(x, |v| -v)
76}
77pub fn abs(x: &Tensor) -> Result<Tensor> {
78    unary_f32(x, |v| v.abs())
79}
80pub fn sqrt(x: &Tensor) -> Result<Tensor> {
81    unary_f32(x, |v| v.sqrt())
82}
83pub fn exp(x: &Tensor) -> Result<Tensor> {
84    unary_f32(x, |v| v.exp())
85}
86pub fn log(x: &Tensor) -> Result<Tensor> {
87    unary_f32(x, |v| v.ln())
88}
89pub fn erf(x: &Tensor) -> Result<Tensor> {
90    unary_f32(x, erf_approx)
91}
92pub fn floor(x: &Tensor) -> Result<Tensor> {
93    unary_f32(x, |v| v.floor())
94}
95pub fn ceil(x: &Tensor) -> Result<Tensor> {
96    unary_f32(x, |v| v.ceil())
97}
98pub fn round(x: &Tensor) -> Result<Tensor> {
99    unary_f32(x, |v| v.round())
100}
101
102// ── Activations ───────────────────────────────────────────────────────────────
103
104pub fn relu(x: &Tensor) -> Result<Tensor> {
105    unary_f32(x, |v| v.max(0.0))
106}
107
108pub fn sigmoid(x: &Tensor) -> Result<Tensor> {
109    unary_f32(x, |v| 1.0 / (1.0 + (-v).exp()))
110}
111
112pub fn tanh_act(x: &Tensor) -> Result<Tensor> {
113    unary_f32(x, |v| v.tanh())
114}
115
116/// GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
117pub fn gelu(x: &Tensor) -> Result<Tensor> {
118    const SQRT_2_OVER_PI: f32 = 0.797_884_56;
119    const COEF: f32 = 0.044_715;
120    unary_f32(x, |v| {
121        let inner = SQRT_2_OVER_PI * (v + COEF * v * v * v);
122        0.5 * v * (1.0 + inner.tanh())
123    })
124}
125
126/// SiLU / Swish: x * sigmoid(x)
127pub fn silu(x: &Tensor) -> Result<Tensor> {
128    unary_f32(x, |v| v / (1.0 + (-v).exp()))
129}
130
131/// Hard Swish: x * relu6(x + 3) / 6
132pub fn hard_swish(x: &Tensor) -> Result<Tensor> {
133    unary_f32(x, |v| v * (v + 3.0).clamp(0.0, 6.0) / 6.0)
134}
135
136pub fn leaky_relu(x: &Tensor, alpha: f32) -> Result<Tensor> {
137    unary_f32(x, |v| if v >= 0.0 { v } else { alpha * v })
138}
139
140pub fn clip(x: &Tensor, min: Option<f32>, max: Option<f32>) -> Result<Tensor> {
141    unary_f32(x, |v| {
142        let v = min.map_or(v, |lo| v.max(lo));
143        max.map_or(v, |hi| v.min(hi))
144    })
145}
146
147// ── Erf approximation (Abramowitz & Stegun) ───────────────────────────────────
148
/// Scalar `erf(x)` via the Abramowitz & Stegun polynomial-times-Gaussian
/// approximation (nominal max error ~1.5e-7).
///
/// The formula is evaluated on `|x|` and the sign is restored at the end using
/// `x.signum()`, so the result is odd in `x` and NaN inputs stay NaN.
fn erf_approx(x: f32) -> f32 {
    // Scale factor inside t = 1 / (1 + P * |x|).
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients a5..a1, highest power first for Horner evaluation.
    const COEFFS: [f32; 5] = [
        1.061_405_43,
        -1.453_152_03,
        1.421_413_74,
        -0.284_496_74,
        0.254_829_59,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner: a1 + (a2 + (a3 + (a4 + a5 * t) * t) * t) * t
    let horner = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &coeff| coeff + acc * t);

    let tail = horner * t * (-magnitude * magnitude).exp();
    x.signum() * (1.0 - tail)
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a rank-1 tensor holding `data`.
    fn t(data: &[f32]) -> Tensor {
        Tensor::from_f32(data, vec![data.len()]).unwrap()
    }

    #[test]
    fn test_add() {
        // Check every output element, not just the first one.
        let r = add(&t(&[1.0, 2.0]), &t(&[3.0, 4.0])).unwrap();
        assert_eq!(r.as_f32_slice(), &[4.0, 6.0]);
    }
    #[test]
    fn test_relu() {
        let r = relu(&t(&[-1.0, 0.0, 1.0])).unwrap();
        let d = r.as_f32_slice();
        assert_eq!(d, &[0.0, 0.0, 1.0]);
    }
    #[test]
    fn test_sigmoid() {
        let v = sigmoid(&t(&[0.0])).unwrap().as_f32_slice()[0];
        assert!((v - 0.5).abs() < 1e-6);
    }
    #[test]
    fn test_gelu() {
        let v = gelu(&t(&[0.0])).unwrap().as_f32_slice()[0];
        assert!(v.abs() < 1e-5);
    }
    #[test]
    fn test_erf() {
        let v = erf_approx(0.0);
        assert!(v.abs() < 1e-6, "erf(0) should be ~0, got {v}");

        // A non-trivial reference point and the odd-symmetry property.
        let one = erf_approx(1.0);
        assert!(
            (one - 0.842_700_8).abs() < 1e-5,
            "erf(1) should be ~0.8427, got {one}"
        );
        assert_eq!(erf_approx(-1.0), -one);
    }
    #[test]
    fn test_scalar_broadcast() {
        let a = t(&[1.0, 2.0, 3.0]);
        let b = t(&[2.0]);
        let r = mul(&a, &b).unwrap();
        assert_eq!(r.as_f32_slice(), &[2.0, 4.0, 6.0]);
    }
    #[test]
    fn test_scalar_broadcast_lhs() {
        // A non-commutative op proves the scalar stays on the left-hand side.
        let a = t(&[10.0]);
        let b = t(&[1.0, 2.0, 3.0]);
        let r = sub(&a, &b).unwrap();
        assert_eq!(r.as_f32_slice(), &[9.0, 8.0, 7.0]);
    }
    #[test]
    fn test_len_mismatch_is_error() {
        // Different element counts with neither operand being a scalar.
        assert!(add(&t(&[1.0, 2.0]), &t(&[1.0, 2.0, 3.0])).is_err());
    }
    #[test]
    fn test_clip() {
        let x = t(&[-2.0, 0.5, 3.0]);
        let both = clip(&x, Some(-1.0), Some(1.0)).unwrap();
        assert_eq!(both.as_f32_slice(), &[-1.0, 0.5, 1.0]);
        let lower_only = clip(&x, Some(0.0), None).unwrap();
        assert_eq!(lower_only.as_f32_slice(), &[0.0, 0.5, 3.0]);
        let unbounded = clip(&x, None, None).unwrap();
        assert_eq!(unbounded.as_f32_slice(), &[-2.0, 0.5, 3.0]);
    }
    #[test]
    fn test_leaky_relu() {
        let r = leaky_relu(&t(&[-2.0, 0.0, 2.0]), 0.5).unwrap();
        assert_eq!(r.as_f32_slice(), &[-1.0, 0.0, 2.0]);
    }
}