Skip to main content

yscv_kernels/ops/
elementwise.rs

1use rayon::{ThreadPool, prelude::*};
2use yscv_tensor::{AlignedVec, Tensor};
3
4use super::super::error::KernelError;
5use super::config::{
6    BinaryKind, PARALLEL_SLICE_CHUNK_ELEMENTS, ParallelElementwiseConfig, should_parallelize_len,
7};
8use super::simd::{
9    binary_same_shape_dispatch, exp_slice_dispatch, relu_slice_dispatch, relu_to_slice_dispatch,
10    sigmoid_slice_dispatch, silu_slice_dispatch, tanh_slice_dispatch,
11};
12
13// GCD low-overhead parallelism (macOS: ~0.3µs dispatch, Linux: scoped threads).
14#[allow(unsafe_code)]
15mod par {
16    #[cfg(target_os = "macos")]
17    use std::ffi::c_void;
18
19    #[cfg(target_os = "macos")]
20    #[allow(unsafe_code)]
21    unsafe extern "C" {
22        fn dispatch_get_global_queue(identifier: isize, flags: usize) -> *const c_void;
23        fn dispatch_apply_f(
24            iterations: usize,
25            queue: *const c_void,
26            context: *mut c_void,
27            work: unsafe extern "C" fn(*mut c_void, usize),
28        );
29    }
30
31    #[cfg(target_os = "macos")]
32    #[inline]
33    #[allow(unsafe_code)]
34    pub fn parallel_for<F: Fn(usize) + Sync>(n: usize, f: F) {
35        #[allow(unsafe_code)]
36        unsafe extern "C" fn call<F: Fn(usize) + Sync>(ctx: *mut c_void, i: usize) {
37            unsafe {
38                (*(ctx as *const F))(i);
39            }
40        }
41        let queue = unsafe { dispatch_get_global_queue(0, 0) };
42        unsafe {
43            dispatch_apply_f(n, queue, &f as *const F as *mut c_void, call::<F>);
44        }
45    }
46
47    #[cfg(not(target_os = "macos"))]
48    #[inline]
49    pub fn parallel_for<F: Fn(usize) + Sync + Send>(n: usize, f: F) {
50        if n <= 1 {
51            for i in 0..n {
52                f(i);
53            }
54            return;
55        }
56        // Use rayon global thread pool — threads are pre-spawned, ~0.5µs dispatch.
57        use rayon::prelude::*;
58        (0..n).into_par_iter().for_each(f);
59    }
60}
61
62/// Elementwise ReLU activation. GCD-parallelized for large tensors.
63#[inline]
64#[allow(unsafe_code)]
65pub fn relu(input: &Tensor) -> Tensor {
66    let input_data = input.data();
67    let len = input_data.len();
68    let mut output = AlignedVec::<f32>::uninitialized(len);
69
70    const PAR_THRESH: usize = 100_000;
71    if len >= PAR_THRESH {
72        let n_chunks = std::thread::available_parallelism()
73            .map(|p| p.get())
74            .unwrap_or(4);
75        let chunk = len.div_ceil(n_chunks);
76        let in_ptr = input_data.as_ptr() as usize;
77        let out_ptr = output.as_mut_ptr() as usize;
78        par::parallel_for(n_chunks, |t| {
79            let start = t * chunk;
80            let end = (start + chunk).min(len);
81            let inp = unsafe {
82                std::slice::from_raw_parts((in_ptr as *const f32).add(start), end - start)
83            };
84            let out = unsafe {
85                std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(start), end - start)
86            };
87            relu_to_slice_dispatch(inp, out);
88        });
89    } else {
90        relu_to_slice_dispatch(input_data, &mut output);
91    }
92
93    Tensor::from_raw_parts(input.shape(), input.strides(), output)
94}
95
96/// In-place ReLU activation: clamps negative values to zero.
97#[inline]
98pub fn relu_inplace(tensor: &mut Tensor) {
99    relu_slice_dispatch(tensor.data_mut());
100}
101
102/// ReLU writing into pre-allocated output tensor. Zero allocation overhead.
103#[inline]
104pub fn relu_out(input: &Tensor, output: &mut Tensor) {
105    debug_assert_eq!(input.shape(), output.shape());
106    relu_to_slice_dispatch(input.data(), output.data_mut());
107}
108
109/// Elementwise sigmoid activation.
110pub fn sigmoid(input: &Tensor) -> Tensor {
111    sigmoid_with_config(input, ParallelElementwiseConfig::disabled())
112}
113
114/// Elementwise ReLU activation with explicit parallelization heuristics.
115pub fn relu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
116    relu_with_config_and_pool(input, config, None)
117}
118
119/// Elementwise sigmoid activation with explicit parallelization heuristics.
120pub fn sigmoid_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
121    sigmoid_with_config_and_pool(input, config, None)
122}
123
124/// # Safety
125/// `AlignedVec::uninitialized` allocates without zeroing. `relu_to_slice_dispatch`
126/// writes every element before anything reads from the buffer.
127#[allow(unsafe_code)]
128pub fn relu_with_config_and_pool(
129    input: &Tensor,
130    config: ParallelElementwiseConfig,
131    thread_pool: Option<&ThreadPool>,
132) -> Tensor {
133    let input_data = input.data();
134    let len = input_data.len();
135    let mut output = AlignedVec::<f32>::uninitialized(len);
136    if should_parallelize_len(len, config.min_parallel_elements, thread_pool) {
137        let mut work = || {
138            output
139                .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
140                .enumerate()
141                .for_each(|(chunk_idx, out_chunk)| {
142                    let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
143                    let end = start + out_chunk.len();
144                    relu_to_slice_dispatch(&input_data[start..end], out_chunk);
145                });
146        };
147        if let Some(pool) = thread_pool {
148            pool.install(work);
149        } else {
150            work();
151        }
152    } else {
153        relu_to_slice_dispatch(input_data, &mut output);
154    }
155    Tensor::from_raw_parts(input.shape(), input.strides(), output)
156}
157
158/// # Safety
159/// `AlignedVec::uninitialized` allocates without zeroing. `sigmoid_slice_dispatch`
160/// writes every element before anything reads from the buffer.
161#[allow(unsafe_code)]
162pub fn sigmoid_with_config_and_pool(
163    input: &Tensor,
164    _config: ParallelElementwiseConfig,
165    _thread_pool: Option<&ThreadPool>,
166) -> Tensor {
167    let input_data = input.data();
168    let len = input_data.len();
169    // Sigmoid uses a heavy polynomial exp + divide per element.
170    // Single-threaded SIMD is faster than rayon chunking for ≤4M elements
171    // due to dispatch overhead (62 tasks × 50µs > compute savings).
172    let mut output = AlignedVec::<f32>::uninitialized(len);
173    sigmoid_slice_dispatch(input_data, &mut output);
174    Tensor::from_raw_parts(input.shape(), input.strides(), output)
175}
176
177/// Elementwise exp activation.
178pub fn exp(input: &Tensor) -> Tensor {
179    exp_with_config(input, ParallelElementwiseConfig::disabled())
180}
181
182/// Elementwise exp activation with explicit parallelization heuristics.
183pub fn exp_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
184    exp_with_config_and_pool(input, config, None)
185}
186
187/// # Safety
188/// `AlignedVec::uninitialized` allocates without zeroing. `exp_slice_dispatch`
189/// writes every element before anything reads from the buffer.
190#[allow(unsafe_code)]
191pub fn exp_with_config_and_pool(
192    input: &Tensor,
193    config: ParallelElementwiseConfig,
194    thread_pool: Option<&ThreadPool>,
195) -> Tensor {
196    let input_data = input.data();
197    let len = input_data.len();
198    let mut output = AlignedVec::<f32>::uninitialized(len);
199    if should_parallelize_len(len, config.min_parallel_elements, thread_pool) {
200        let mut work = || {
201            output
202                .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
203                .enumerate()
204                .for_each(|(chunk_idx, out_chunk)| {
205                    let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
206                    let end = start + out_chunk.len();
207                    exp_slice_dispatch(&input_data[start..end], out_chunk);
208                });
209        };
210        if let Some(pool) = thread_pool {
211            pool.install(work);
212        } else {
213            work();
214        }
215    } else {
216        exp_slice_dispatch(input_data, &mut output);
217    }
218    Tensor::from_raw_parts(input.shape(), input.strides(), output)
219}
220
221/// Elementwise tanh activation.
222pub fn tanh_act(input: &Tensor) -> Tensor {
223    tanh_act_with_config(input, ParallelElementwiseConfig::disabled())
224}
225
226/// Elementwise tanh activation with explicit parallelization heuristics.
227pub fn tanh_act_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
228    tanh_act_with_config_and_pool(input, config, None)
229}
230
231/// # Safety
232/// `AlignedVec::uninitialized` allocates without zeroing. `tanh_slice_dispatch`
233/// writes every element before anything reads from the buffer.
234#[allow(unsafe_code)]
235pub fn tanh_act_with_config_and_pool(
236    input: &Tensor,
237    _config: ParallelElementwiseConfig,
238    _thread_pool: Option<&ThreadPool>,
239) -> Tensor {
240    let input_data = input.data();
241    let len = input_data.len();
242    // Tanh uses a heavy polynomial exp + divide per element.
243    // Single-threaded SIMD is faster than rayon chunking for typical sizes.
244    let mut output = AlignedVec::<f32>::uninitialized(len);
245    tanh_slice_dispatch(input_data, &mut output);
246    Tensor::from_raw_parts(input.shape(), input.strides(), output)
247}
248
/// Element count at or above which `gelu` and `mish` switch to rayon chunked processing.
const ACTIVATION_PARALLEL_THRESHOLD: usize = 1 << 16;
/// Elements per rayon chunk on the parallel paths of `gelu` and `mish`.
const ACTIVATION_CHUNK_SIZE: usize = 8 * 1024;
251
252/// Elementwise GELU activation (fast approximation): `x * sigmoid(1.702 * x)`.
253///
254/// # Safety
255/// `AlignedVec::uninitialized` allocates without zeroing. `gelu_slice_out`
256/// writes every element before anything reads from the buffer.
257#[allow(unsafe_code)]
258pub fn gelu(input: &Tensor) -> Tensor {
259    let src = input.data();
260    let len = src.len();
261    let mut output = AlignedVec::<f32>::uninitialized(len);
262    if len >= ACTIVATION_PARALLEL_THRESHOLD {
263        output
264            .par_chunks_mut(ACTIVATION_CHUNK_SIZE)
265            .enumerate()
266            .for_each(|(ci, out_chunk)| {
267                let start = ci * ACTIVATION_CHUNK_SIZE;
268                gelu_slice_out(&src[start..start + out_chunk.len()], out_chunk);
269            });
270    } else {
271        gelu_slice_out(src, &mut output);
272    }
273    Tensor::from_raw_parts(input.shape(), input.strides(), output)
274}
275
276/// Elementwise SiLU (Swish) activation: `x * sigmoid(x)`.
277///
278/// Uses fused SIMD kernel (sigmoid + multiply in one pass) to halve memory bandwidth.
279pub fn silu(input: &Tensor) -> Tensor {
280    silu_with_config(input, ParallelElementwiseConfig::disabled())
281}
282
283/// Elementwise SiLU activation with explicit parallelization heuristics.
284pub fn silu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
285    silu_with_config_and_pool(input, config, None)
286}
287
288/// # Safety
289/// `AlignedVec::uninitialized` allocates without zeroing. `silu_slice_dispatch`
290/// writes every element before anything reads from the buffer.
291#[allow(unsafe_code)]
292pub fn silu_with_config_and_pool(
293    input: &Tensor,
294    _config: ParallelElementwiseConfig,
295    _thread_pool: Option<&ThreadPool>,
296) -> Tensor {
297    let input_data = input.data();
298    let len = input_data.len();
299    // SiLU uses a heavy polynomial exp + divide + multiply per element.
300    // Single-threaded SIMD is faster than rayon chunking for typical sizes.
301    let mut output = AlignedVec::<f32>::uninitialized(len);
302    silu_slice_dispatch(input_data, &mut output);
303    Tensor::from_raw_parts(input.shape(), input.strides(), output)
304}
305
306/// Elementwise Mish activation: `x * tanh(softplus(x))` = `x * tanh(ln(1 + exp(x)))`.
307pub fn mish(input: &Tensor) -> Tensor {
308    let mut output = input.clone();
309    let data = output.data_mut();
310    if data.len() >= ACTIVATION_PARALLEL_THRESHOLD {
311        data.par_chunks_mut(ACTIVATION_CHUNK_SIZE)
312            .for_each(mish_slice);
313    } else {
314        mish_slice(data);
315    }
316    output
317}
318
/// Writes the sigmoid-approximated GELU of each `src` element into `dst`:
/// `dst[i] = src[i] * sigmoid(1.702 * src[i])`.
///
/// Elements of `dst` beyond `src.len()` are left untouched.
///
/// # Panics
/// Panics if `dst` is shorter than `src`.
fn gelu_slice_out(src: &[f32], dst: &mut [f32]) {
    // Re-slicing once hoists the length check out of the loop, so the zipped iteration
    // below needs no per-element bounds checks.
    let dst = &mut dst[..src.len()];
    for (out, &x) in dst.iter_mut().zip(src) {
        let a = 1.702 * x;
        *out = x * (1.0 / (1.0 + (-a).exp()));
    }
}
328
/// Applies Mish, `x * tanh(softplus(x))` with `softplus(x) = ln(1 + exp(x))`, to every
/// element of `data` in place.
///
/// Softplus is evaluated as `exp(x).ln_1p()`, which matters for strongly negative `x`.
/// Once `exp(x)` falls below half an ulp of 1.0 (x ≲ -16.6), `1.0 + exp(x)` rounds to exactly
/// `1.0`, `ln` returns 0, and the result collapses to zero. `ln_1p` keeps the small value,
/// giving `≈ x * exp(x)`. For large `x`, `exp(x)` overflows to +inf, `ln_1p(inf)` is +inf and
/// `tanh(inf)` is 1, so the result is `x`.
fn mish_slice(data: &mut [f32]) {
    for v in data.iter_mut() {
        let x = *v;
        let softplus = x.exp().ln_1p();
        *v = x * softplus.tanh();
    }
}
336
337/// Elementwise add with optional parallel same-shape execution.
338pub fn add_with_config(
339    lhs: &Tensor,
340    rhs: &Tensor,
341    config: ParallelElementwiseConfig,
342) -> Result<Tensor, KernelError> {
343    add_with_config_and_pool(lhs, rhs, config, None)
344}
345
346pub fn add_with_config_and_pool(
347    lhs: &Tensor,
348    rhs: &Tensor,
349    config: ParallelElementwiseConfig,
350    thread_pool: Option<&ThreadPool>,
351) -> Result<Tensor, KernelError> {
352    binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Add)
353}
354
355/// Elementwise subtract with optional parallel same-shape execution.
356pub fn sub_with_config(
357    lhs: &Tensor,
358    rhs: &Tensor,
359    config: ParallelElementwiseConfig,
360) -> Result<Tensor, KernelError> {
361    sub_with_config_and_pool(lhs, rhs, config, None)
362}
363
364pub fn sub_with_config_and_pool(
365    lhs: &Tensor,
366    rhs: &Tensor,
367    config: ParallelElementwiseConfig,
368    thread_pool: Option<&ThreadPool>,
369) -> Result<Tensor, KernelError> {
370    binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Sub)
371}
372
373/// Elementwise multiply with optional parallel same-shape execution.
374pub fn mul_with_config(
375    lhs: &Tensor,
376    rhs: &Tensor,
377    config: ParallelElementwiseConfig,
378) -> Result<Tensor, KernelError> {
379    mul_with_config_and_pool(lhs, rhs, config, None)
380}
381
382pub fn mul_with_config_and_pool(
383    lhs: &Tensor,
384    rhs: &Tensor,
385    config: ParallelElementwiseConfig,
386    thread_pool: Option<&ThreadPool>,
387) -> Result<Tensor, KernelError> {
388    binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Mul)
389}
390
391fn binary_with_config_and_pool(
392    lhs: &Tensor,
393    rhs: &Tensor,
394    config: ParallelElementwiseConfig,
395    thread_pool: Option<&ThreadPool>,
396    kind: BinaryKind,
397) -> Result<Tensor, KernelError> {
398    if lhs.shape() != rhs.shape() {
399        return binary_fallback(lhs, rhs, kind);
400    }
401
402    let left = lhs.data();
403    let right = rhs.data();
404    let shape = lhs.shape().to_vec();
405    let mut output = AlignedVec::<f32>::uninitialized(left.len());
406
407    if should_parallelize_len(left.len(), config.min_parallel_elements, thread_pool) {
408        let mut work = || {
409            output
410                .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
411                .enumerate()
412                .for_each(|(chunk_idx, out_chunk)| {
413                    let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
414                    let end = start + out_chunk.len();
415                    binary_same_shape_dispatch(
416                        &left[start..end],
417                        &right[start..end],
418                        out_chunk,
419                        kind,
420                    );
421                });
422        };
423
424        if let Some(pool) = thread_pool {
425            pool.install(work);
426        } else {
427            work();
428        }
429    } else {
430        binary_same_shape_dispatch(left, right, &mut output, kind);
431    }
432
433    Tensor::from_aligned(shape, output).map_err(Into::into)
434}
435
436fn binary_fallback(lhs: &Tensor, rhs: &Tensor, kind: BinaryKind) -> Result<Tensor, KernelError> {
437    match kind {
438        BinaryKind::Add => lhs.add(rhs),
439        BinaryKind::Sub => lhs.sub(rhs),
440        BinaryKind::Mul => lhs.mul(rhs),
441    }
442    .map_err(Into::into)
443}