#![deny(unsafe_code)]
//! Crate root for the `yscv-kernels` compute-kernel library.
//!
//! Declares the backend/ops modules and re-exports their public API.
//! CPU backends are always available; the GPU backends, sessions, and
//! multi-device support are compiled only behind the `gpu` cargo feature.

// Crate identifier string (e.g. for logging/diagnostics).
// NOTE(review): matches the package name "yscv-kernels" — confirm against Cargo.toml.
pub const CRATE_ID: &str = "yscv-kernels";

// NOTE(review): each `#[path = "..."]` below names the location rustc would
// resolve by default for that module, so the attributes appear redundant.
// Kept as-is — confirm no non-standard file layout or tooling depends on them.
#[path = "backend.rs"]
mod backend;
#[path = "error.rs"]
mod error;
// GPU-only modules: present in the build only when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
#[path = "gpu_backend.rs"]
mod gpu_backend;
#[cfg(feature = "gpu")]
#[path = "gpu_session.rs"]
mod gpu_session;
#[cfg(feature = "gpu")]
#[path = "multi_device.rs"]
mod multi_device;
#[path = "ops/mod.rs"]
mod ops;
21
// ---------------------------------------------------------------------------
// Public API surface: flat re-exports so callers use `yscv_kernels::matmul_2d`
// etc. without naming the internal module layout.
// ---------------------------------------------------------------------------

// CPU backend types (Backend trait impls, per-op parameter structs, threaded
// backend config) and the kernel entry points, including the `_with_config`
// variants that take explicit parallelism settings.
pub use backend::{
    Backend, BackwardOps, BatchNorm2dParams, CpuBackend, GroupNormNhwcParams,
    LayerNormLastDimParams, RmsNormLastDimParams, SeparableConv2dParams, ThreadedCpuBackend,
    ThreadedCpuBackendConfig, add, add_with_config, avg_pool2d_nhwc, avg_pool2d_nhwc_with_config,
    batch_norm2d_nhwc, batch_norm2d_nhwc_with_config, conv2d_nhwc, conv2d_nhwc_with_config,
    deformable_conv2d_nhwc, depthwise_conv2d_nhwc, depthwise_conv2d_nhwc_with_config, dropout,
    embedding_lookup, exp, exp_with_config, flash_attention, gelu, group_norm_nhwc,
    group_norm_nhwc_with_config, layer_norm_last_dim, layer_norm_last_dim_with_config,
    log_softmax_last_dim, log_softmax_last_dim_with_config, logsumexp_last_dim,
    logsumexp_last_dim_with_config, matmul_2d, matmul_2d_sequential, matmul_2d_with_config,
    matmul_2d_with_threads, max_pool2d_nhwc, max_pool2d_nhwc_with_config, mish, mul,
    mul_with_config, relu, relu_inplace, relu_with_config, rms_norm_last_dim,
    rms_norm_last_dim_with_config, scaled_dot_product_attention, separable_conv2d_nhwc,
    separable_conv2d_nhwc_with_config, sigmoid, sigmoid_with_config, silu, softmax_last_dim,
    softmax_last_dim_with_config, sub, sub_with_config, tanh_act, tanh_act_with_config,
    transpose_conv2d_nhwc,
};
// Shared error type for all kernel entry points.
pub use error::KernelError;
// GPU-only API, mirrored behind the same `gpu` feature gates as the modules above.
#[cfg(feature = "gpu")]
pub use gpu_backend::{GpuBackend, gpu_batch_norm, gpu_layer_norm, gpu_transpose};
#[cfg(feature = "gpu")]
pub use gpu_session::GpuSession;
#[cfg(feature = "gpu")]
pub use multi_device::{
    GpuApiBackend, GpuDeviceInfo, GpuDeviceType, MultiGpuBackend, SchedulingStrategy,
    enumerate_gpu_devices,
};
// Lower-level ops: parallelism tuning constants/configs, dispatch helpers, and
// the `_with_config_and_pool` variants that additionally take a thread pool.
pub use ops::{
    BinaryKind, DEFAULT_ELEMENTWISE_MIN_PARALLEL_ELEMENTS,
    DEFAULT_MATMUL_MIN_PARALLEL_OUTPUT_ELEMENTS, DEFAULT_MATMUL_MIN_PARALLEL_SHARED_DIM,
    ParallelElementwiseConfig, ParallelMatmulConfig, add_reduce_dispatch, add_with_config_and_pool,
    avg_pool2d_nhwc_with_config_and_pool, batch_norm2d_nhwc_with_config_and_pool,
    binary_same_shape_dispatch, conv2d_nhwc_with_config_and_pool, conv3d,
    depthwise_conv2d_nhwc_with_config_and_pool, exp_slice_dispatch, exp_with_config_and_pool,
    fma_slice_dispatch, group_norm_nhwc_with_config_and_pool,
    layer_norm_last_dim_with_config_and_pool, log_softmax_last_dim_with_config_and_pool,
    logsumexp_last_dim_with_config_and_pool, matmul_2d_with_config_and_pool, matmul_row_dispatch,
    max_pool2d_nhwc_with_config_and_pool, max_reduce_dispatch, mul_with_config_and_pool, relu_out,
    relu_slice_dispatch, relu_to_slice_dispatch, relu_with_config_and_pool,
    rms_norm_last_dim_with_config_and_pool, separable_conv2d_nhwc_with_config_and_pool,
    sigmoid_slice_dispatch, sigmoid_with_config_and_pool, softmax_last_dim_with_config_and_pool,
    sub_exp_slice_dispatch, sub_with_config_and_pool, tanh_act_with_config_and_pool,
    tanh_slice_dispatch,
};
66
// Unit tests kept under `src` (in `tests/mod.rs`) rather than the top-level
// integration-test directory; compiled only for `cargo test` via `#[cfg(test)]`.
// NOTE(review): convention puts `#[cfg(test)]` before `#[path]`; attribute
// order is functionally irrelevant, so the original order is preserved.
#[path = "tests/mod.rs"]
#[cfg(test)]
mod tests;