// sapient_backends_cpu/kernels/elementwise.rs
use sapient_core::error::{Result, SapientError};
use sapient_core::{DType, Tensor};
8
9fn unary_f32<F: Fn(f32) -> f32>(x: &Tensor, f: F) -> Result<Tensor> {
13 if x.dtype() != DType::F32 {
14 return Err(SapientError::TypeMismatch {
15 expected: "f32".into(),
16 got: x.dtype().to_string(),
17 });
18 }
19 let data: Vec<f32> = x.to_f32_cow().iter().map(|&v| f(v)).collect();
20 Tensor::from_f32(&data, x.shape().clone())
21}
22
23fn binary_f32<F: Fn(f32, f32) -> f32>(a: &Tensor, b: &Tensor, f: F) -> Result<Tensor> {
25 let a_cow = a.to_f32_cow();
27 let a_data = a_cow.as_ref();
28 let b_cow = b.to_f32_cow();
29 let b_data = b_cow.as_ref();
30
31 let (out, shape) = if a_data.len() == b_data.len() {
32 let out: Vec<f32> = a_data
33 .iter()
34 .zip(b_data.iter())
35 .map(|(&x, &y)| f(x, y))
36 .collect();
37 (out, a.shape().clone())
38 } else if b_data.len() == 1 {
39 let scalar = b_data[0];
40 let out: Vec<f32> = a_data.iter().map(|&x| f(x, scalar)).collect();
41 (out, a.shape().clone())
42 } else if a_data.len() == 1 {
43 let scalar = a_data[0];
44 let out: Vec<f32> = b_data.iter().map(|&y| f(scalar, y)).collect();
45 (out, b.shape().clone())
46 } else {
47 return Err(SapientError::ShapeMismatch {
48 expected: a.shape().dims().to_vec(),
49 got: b.shape().dims().to_vec(),
50 });
51 };
52
53 Tensor::from_f32(&out, shape)
54}
55
56pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
59 binary_f32(a, b, |x, y| x + y)
60}
61pub fn sub(a: &Tensor, b: &Tensor) -> Result<Tensor> {
62 binary_f32(a, b, |x, y| x - y)
63}
64pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
65 binary_f32(a, b, |x, y| x * y)
66}
67pub fn div(a: &Tensor, b: &Tensor) -> Result<Tensor> {
68 binary_f32(a, b, |x, y| x / y)
69}
70pub fn pow(a: &Tensor, b: &Tensor) -> Result<Tensor> {
71 binary_f32(a, b, |x, y| x.powf(y))
72}
73
74pub fn neg(x: &Tensor) -> Result<Tensor> {
75 unary_f32(x, |v| -v)
76}
77pub fn abs(x: &Tensor) -> Result<Tensor> {
78 unary_f32(x, |v| v.abs())
79}
80pub fn sqrt(x: &Tensor) -> Result<Tensor> {
81 unary_f32(x, |v| v.sqrt())
82}
83pub fn exp(x: &Tensor) -> Result<Tensor> {
84 unary_f32(x, |v| v.exp())
85}
86pub fn log(x: &Tensor) -> Result<Tensor> {
87 unary_f32(x, |v| v.ln())
88}
89pub fn erf(x: &Tensor) -> Result<Tensor> {
90 unary_f32(x, erf_approx)
91}
92pub fn floor(x: &Tensor) -> Result<Tensor> {
93 unary_f32(x, |v| v.floor())
94}
95pub fn ceil(x: &Tensor) -> Result<Tensor> {
96 unary_f32(x, |v| v.ceil())
97}
98pub fn round(x: &Tensor) -> Result<Tensor> {
99 unary_f32(x, |v| v.round())
100}
101
102pub fn relu(x: &Tensor) -> Result<Tensor> {
105 unary_f32(x, |v| v.max(0.0))
106}
107
108pub fn sigmoid(x: &Tensor) -> Result<Tensor> {
109 unary_f32(x, |v| 1.0 / (1.0 + (-v).exp()))
110}
111
112pub fn tanh_act(x: &Tensor) -> Result<Tensor> {
113 unary_f32(x, |v| v.tanh())
114}
115
116pub fn gelu(x: &Tensor) -> Result<Tensor> {
118 const SQRT_2_OVER_PI: f32 = 0.797_884_56;
119 const COEF: f32 = 0.044_715;
120 unary_f32(x, |v| {
121 let inner = SQRT_2_OVER_PI * (v + COEF * v * v * v);
122 0.5 * v * (1.0 + inner.tanh())
123 })
124}
125
126pub fn silu(x: &Tensor) -> Result<Tensor> {
128 unary_f32(x, |v| v / (1.0 + (-v).exp()))
129}
130
131pub fn hard_swish(x: &Tensor) -> Result<Tensor> {
133 unary_f32(x, |v| v * (v + 3.0).clamp(0.0, 6.0) / 6.0)
134}
135
136pub fn leaky_relu(x: &Tensor, alpha: f32) -> Result<Tensor> {
137 unary_f32(x, |v| if v >= 0.0 { v } else { alpha * v })
138}
139
140pub fn clip(x: &Tensor, min: Option<f32>, max: Option<f32>) -> Result<Tensor> {
141 unary_f32(x, |v| {
142 let v = min.map_or(v, |lo| v.max(lo));
143 max.map_or(v, |hi| v.min(hi))
144 })
145}
146
/// Polynomial approximation of the error function.
///
/// Evaluates `sign(x) * (1 - P(t) * t * e^(-x^2))` with `t = 1 / (1 + p * |x|)`,
/// where `P` is a degree-4 polynomial in `t`. The constants match formula 7.1.26
/// in Abramowitz & Stegun.
fn erf_approx(x: f32) -> f32 {
    // Highest-order coefficient, then the remaining ones in descending order,
    // so the fold below is a plain Horner evaluation.
    const LEADING: f32 = 1.061_405_43;
    const REMAINING: [f32; 4] = [-1.453_152_03, 1.421_413_74, -0.284_496_74, 0.254_829_59];

    let sign = x.signum();
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * magnitude);

    let horner = REMAINING.iter().fold(LEADING, |acc, &c| c + acc * t);
    let tail = horner * t * (-magnitude * magnitude).exp();

    sign * (1.0 - tail)
}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `values` in a one-dimensional f32 tensor of matching length.
    fn t(values: &[f32]) -> Tensor {
        let shape = vec![values.len()];
        Tensor::from_f32(values, shape).unwrap()
    }

    #[test]
    fn test_add() {
        let lhs = t(&[1.0, 2.0]);
        let rhs = t(&[3.0, 4.0]);
        let sum = add(&lhs, &rhs).unwrap();
        let first = sum.as_f32_slice()[0];
        assert!((first - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_relu() {
        let out = relu(&t(&[-1.0, 0.0, 1.0])).unwrap();
        assert_eq!(out.as_f32_slice(), &[0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_sigmoid() {
        let out = sigmoid(&t(&[0.0])).unwrap();
        let v = out.as_f32_slice()[0];
        assert!((v - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_gelu() {
        let out = gelu(&t(&[0.0])).unwrap();
        let v = out.as_f32_slice()[0];
        assert!(v.abs() < 1e-5);
    }

    #[test]
    fn test_erf() {
        let v = erf_approx(0.0);
        assert!(v.abs() < 1e-6, "erf(0) should be ~0, got {v}");
    }

    #[test]
    fn test_scalar_broadcast() {
        // A one-element right operand is applied to every element of the left one.
        let vector = t(&[1.0, 2.0, 3.0]);
        let scalar = t(&[2.0]);
        let scaled = mul(&vector, &scalar).unwrap();
        assert_eq!(scaled.as_f32_slice(), &[2.0, 4.0, 6.0]);
    }
}