//! `yscv_kernels/ops/config.rs` — parallel-execution thresholds, tuning
//! constants, and per-op launch plans/specs for CPU kernels.

1use rayon::ThreadPool;
2use yscv_tensor::Tensor;
3
/// Default minimum `m * n` output cells before CPU matmul goes row-parallel
/// (see [`ParallelMatmulConfig`]).
pub const DEFAULT_MATMUL_MIN_PARALLEL_OUTPUT_ELEMENTS: usize = 65_536;
/// Default minimum shared dimension `k` before CPU matmul goes row-parallel
/// (see [`ParallelMatmulConfig`]).
pub const DEFAULT_MATMUL_MIN_PARALLEL_SHARED_DIM: usize = 128;
// WHY 262144: 256K floats = 1MB; below this, rayon dispatch overhead (~3-5us) exceeds compute savings.
/// Default minimum element count before elementwise ops run in parallel.
pub const DEFAULT_ELEMENTWISE_MIN_PARALLEL_ELEMENTS: usize = 262_144;
/// Higher threshold for transcendental ops (sigmoid, tanh, exp, etc.)
/// where per-element compute is heavy enough that threading overhead
/// is comparable to the gain at smaller sizes (~1M elements).
// NOTE(review): the allow suggests no op wires this threshold up yet —
// confirm and drop the attribute once a transcendental kernel uses it.
#[allow(dead_code)]
pub const DEFAULT_TRANSCENDENTAL_MIN_PARALLEL_ELEMENTS: usize = 1_048_576;
// WHY 16384: 64KB per chunk (16K x 4B) fits in L1 cache; enough work per thread to amortize dispatch.
/// Chunk size, in elements, used when splitting slices across rayon workers.
pub(crate) const PARALLEL_SLICE_CHUNK_ELEMENTS: usize = 16_384;
15
/// Parallel heuristics for CPU elementwise operations.
///
/// Setting the threshold to `usize::MAX` (see [`Self::disabled`]) forces the
/// sequential path for every input size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelElementwiseConfig {
    /// Minimum number of tensor elements required before parallel elementwise execution.
    pub min_parallel_elements: usize,
}
22
23impl ParallelElementwiseConfig {
24    /// Disable parallel execution and force sequential elementwise execution.
25    pub const fn disabled() -> Self {
26        Self {
27            min_parallel_elements: usize::MAX,
28        }
29    }
30}
31
32impl Default for ParallelElementwiseConfig {
33    fn default() -> Self {
34        Self {
35            min_parallel_elements: DEFAULT_ELEMENTWISE_MIN_PARALLEL_ELEMENTS,
36        }
37    }
38}
39
/// Parallel heuristics for CPU matmul row-splitting.
///
/// Both gates must pass before the row-parallel path is taken; setting either
/// to `usize::MAX` (see [`Self::disabled`]) forces the sequential path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelMatmulConfig {
    /// Minimum `m * n` output cells needed before row-parallel execution is considered.
    pub min_parallel_output_elements: usize,
    /// Minimum shared dimension (`k`) needed before row-parallel execution is considered.
    pub min_parallel_shared_dim: usize,
}
48
49impl ParallelMatmulConfig {
50    /// Disable parallel execution and force sequential matmul path.
51    pub const fn disabled() -> Self {
52        Self {
53            min_parallel_output_elements: usize::MAX,
54            min_parallel_shared_dim: usize::MAX,
55        }
56    }
57}
58
59impl Default for ParallelMatmulConfig {
60    fn default() -> Self {
61        Self {
62            min_parallel_output_elements: DEFAULT_MATMUL_MIN_PARALLEL_OUTPUT_ELEMENTS,
63            min_parallel_shared_dim: DEFAULT_MATMUL_MIN_PARALLEL_SHARED_DIM,
64        }
65    }
66}
67
/// Dimensions describing one `m x k` by `k x n` matmul dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct MatMulPlan {
    /// Rows of the left operand (and of the output).
    pub m: usize,
    /// Shared dimension: columns of the left operand, rows of the right.
    pub k: usize,
    /// Columns of the right operand (and of the output).
    pub n: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
75
/// Dimensions describing one 2-D pooling dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Pool2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input height.
    pub in_h: usize,
    /// Input width.
    pub in_w: usize,
    /// Number of channels.
    pub channels: usize,
    /// Output height.
    pub out_h: usize,
    /// Output width.
    pub out_w: usize,
    /// Pooling window height.
    pub kernel_h: usize,
    /// Pooling window width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
90
/// Caller-supplied window and stride sizes for a 2-D pooling op.
#[derive(Debug, Clone, Copy)]
pub struct Pool2dSpec {
    /// Pooling window height.
    pub kernel_h: usize,
    /// Pooling window width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
98
/// Dimensions describing one 2-D convolution dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Conv2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input height.
    pub in_h: usize,
    /// Input width.
    pub in_w: usize,
    /// Number of input channels.
    pub in_channels: usize,
    /// Output height.
    pub out_h: usize,
    /// Output width.
    pub out_w: usize,
    /// Number of output channels (filters).
    pub out_channels: usize,
    /// Filter height.
    pub kernel_h: usize,
    /// Filter width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
114
/// Caller-supplied strides for a standard 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct Conv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
120
/// Dimensions describing one depthwise 2-D convolution dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DepthwiseConv2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input height.
    pub in_h: usize,
    /// Input width.
    pub in_w: usize,
    /// Number of input channels.
    pub channels: usize,
    /// Output channels produced per input channel.
    // NOTE(review): presumably out_channels == channels * depth_multiplier —
    // confirm against the planner that builds this struct.
    pub depth_multiplier: usize,
    /// Output height.
    pub out_h: usize,
    /// Output width.
    pub out_w: usize,
    /// Total number of output channels.
    pub out_channels: usize,
    /// Filter height.
    pub kernel_h: usize,
    /// Filter width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
137
/// Caller-supplied strides for a depthwise 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct DepthwiseConv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
143
/// Caller-supplied strides for a separable (depthwise + pointwise) 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
149
/// Borrowed kernel and bias tensors for a separable 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dKernels<'a> {
    /// Kernel for the depthwise stage.
    pub depthwise_kernel: &'a Tensor,
    /// Optional bias applied after the depthwise stage.
    pub depthwise_bias: Option<&'a Tensor>,
    /// Kernel for the pointwise stage.
    pub pointwise_kernel: &'a Tensor,
    /// Optional bias applied after the pointwise stage.
    pub pointwise_bias: Option<&'a Tensor>,
}
157
/// Borrowed parameter tensors for 2-D batch normalization.
#[derive(Debug, Clone, Copy)]
pub struct BatchNorm2dTensors<'a> {
    /// Scale (`gamma`) tensor.
    pub gamma: &'a Tensor,
    /// Shift (`beta`) tensor.
    pub beta: &'a Tensor,
    /// Mean tensor used for normalization.
    pub mean: &'a Tensor,
    /// Variance tensor used for normalization.
    pub variance: &'a Tensor,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
166
/// Borrowed parameter tensors for layer normalization over the last dimension.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormLastDimTensors<'a> {
    /// Scale (`gamma`) tensor.
    pub gamma: &'a Tensor,
    /// Shift (`beta`) tensor.
    pub beta: &'a Tensor,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
173
/// Dimensions describing one 2-D batch-normalization dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BatchNorm2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Spatial height.
    pub height: usize,
    /// Spatial width.
    pub width: usize,
    /// Number of channels.
    pub channels: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
182
/// Dimensions describing one softmax dispatch over fixed-length rows.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SoftmaxPlan {
    /// Number of elements in each softmax row (the reduced dimension).
    pub row_len: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
188
/// Dimensions describing one log-sum-exp dispatch.
// Clone but not Copy: holds an owned Vec for the output shape.
#[derive(Debug, Clone)]
pub(crate) struct LogSumExpPlan {
    /// Number of elements in each reduced row.
    pub row_len: usize,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
195
/// Dimensions describing one layer-normalization dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct LayerNormPlan {
    /// Number of elements in each normalized row.
    pub row_len: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
201
/// Borrowed parameters for 2-D group normalization.
#[derive(Debug, Clone, Copy)]
pub struct GroupNorm2dTensors<'a> {
    /// Scale (`gamma`) tensor.
    pub gamma: &'a Tensor,
    /// Shift (`beta`) tensor.
    pub beta: &'a Tensor,
    /// Number of channel groups to normalize over.
    pub num_groups: usize,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
209
/// Dimensions describing one 2-D group-normalization dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct GroupNorm2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Spatial height.
    pub height: usize,
    /// Spatial width.
    pub width: usize,
    /// Number of channels.
    pub channels: usize,
    /// Number of channel groups.
    pub num_groups: usize,
    /// Channels in each group (`channels / num_groups`).
    pub channels_per_group: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
220
/// Borrowed parameters for RMS normalization over the last dimension.
#[derive(Debug, Clone, Copy)]
pub struct RmsNormLastDimTensors<'a> {
    /// Scale (`gamma`) tensor.
    pub gamma: &'a Tensor,
    /// Small constant for numerical stability.
    pub epsilon: f32,
}
226
/// Dimensions describing one RMS-normalization dispatch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct RmsNormPlan {
    /// Number of elements in each normalized row.
    pub row_len: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
232
/// Selector for an elementwise binary operation.
#[derive(Debug, Clone, Copy)]
pub enum BinaryKind {
    /// Elementwise addition.
    Add,
    /// Elementwise subtraction.
    Sub,
    /// Elementwise multiplication.
    Mul,
}
239
/// Selector for the 2-D pooling reduction.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Pool2dKind {
    /// Max pooling.
    Max,
    /// Average pooling.
    Avg,
}
245
246pub(crate) fn should_parallelize_len(
247    len: usize,
248    min_parallel_len: usize,
249    thread_pool: Option<&ThreadPool>,
250) -> bool {
251    if cfg!(miri) {
252        return false;
253    }
254    if len < min_parallel_len {
255        return false;
256    }
257    available_threads(thread_pool) > 1
258}
259
260pub(crate) fn available_threads(thread_pool: Option<&ThreadPool>) -> usize {
261    thread_pool
262        .map(ThreadPool::current_num_threads)
263        .unwrap_or_else(rayon::current_num_threads)
264}