1use rayon::ThreadPool;
2use yscv_tensor::Tensor;
3
/// Minimum number of matmul output elements before the parallel path is used.
/// NOTE(review): the exact comparison lives at the call sites — confirm there.
pub const DEFAULT_MATMUL_MIN_PARALLEL_OUTPUT_ELEMENTS: usize = 65_536;
/// Minimum shared (inner) matmul dimension before the parallel path is used.
pub const DEFAULT_MATMUL_MIN_PARALLEL_SHARED_DIM: usize = 128;
/// Minimum element count before an elementwise op is run in parallel.
pub const DEFAULT_ELEMENTWISE_MIN_PARALLEL_ELEMENTS: usize = 262_144;
/// Higher elementwise threshold named for transcendental ops — presumably
/// tuned for more expensive per-element work (TODO confirm against users).
// Kept even when no build configuration currently references it.
#[allow(dead_code)]
pub const DEFAULT_TRANSCENDENTAL_MIN_PARALLEL_ELEMENTS: usize = 1_048_576;
/// Chunk size, in elements, used when splitting slices for parallel work.
pub(crate) const PARALLEL_SLICE_CHUNK_ELEMENTS: usize = 16_384;
15
/// Tuning knob controlling when elementwise kernels switch to the parallel
/// execution path — presumably consulted via `should_parallelize_len`;
/// confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelElementwiseConfig {
    /// Element-count threshold below which the op stays serial.
    pub min_parallel_elements: usize,
}

impl ParallelElementwiseConfig {
    /// Builds a configuration that keeps elementwise ops serial: the
    /// threshold is raised to `usize::MAX`, which no realistic element
    /// count reaches.
    pub const fn disabled() -> Self {
        ParallelElementwiseConfig {
            min_parallel_elements: usize::MAX,
        }
    }
}
31
32impl Default for ParallelElementwiseConfig {
33 fn default() -> Self {
34 Self {
35 min_parallel_elements: DEFAULT_ELEMENTWISE_MIN_PARALLEL_ELEMENTS,
36 }
37 }
38}
39
/// Tuning knobs controlling when matrix multiplication switches to the
/// parallel execution path — presumably both thresholds must be met;
/// confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelMatmulConfig {
    /// Output-element-count threshold below which matmul stays serial.
    pub min_parallel_output_elements: usize,
    /// Shared (inner) dimension threshold below which matmul stays serial.
    pub min_parallel_shared_dim: usize,
}

impl ParallelMatmulConfig {
    /// Builds a configuration that keeps matmul serial: both thresholds are
    /// raised to `usize::MAX` so realistic problem sizes never reach them.
    pub const fn disabled() -> Self {
        ParallelMatmulConfig {
            min_parallel_output_elements: usize::MAX,
            min_parallel_shared_dim: usize::MAX,
        }
    }
}
58
59impl Default for ParallelMatmulConfig {
60 fn default() -> Self {
61 Self {
62 min_parallel_output_elements: DEFAULT_MATMUL_MIN_PARALLEL_OUTPUT_ELEMENTS,
63 min_parallel_shared_dim: DEFAULT_MATMUL_MIN_PARALLEL_SHARED_DIM,
64 }
65 }
66}
67
/// Precomputed dimensions for one matrix-multiplication call.
#[derive(Debug, Clone, Copy)]
pub(crate) struct MatMulPlan {
    /// Output row count — presumably the left operand's rows; confirm in the kernel.
    pub m: usize,
    /// Shared (inner) dimension contracted over.
    pub k: usize,
    /// Output column count.
    pub n: usize,
    /// Total number of elements in the output buffer.
    pub output_len: usize,
}
75
/// Precomputed geometry for a 2-D pooling kernel.
/// NOTE(review): axis layout (NHWC vs NCHW) is not visible from this file —
/// confirm against the kernel that consumes the plan.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Pool2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input spatial height.
    pub in_h: usize,
    /// Input spatial width.
    pub in_w: usize,
    /// Number of channels.
    pub channels: usize,
    /// Output spatial height.
    pub out_h: usize,
    /// Output spatial width.
    pub out_w: usize,
    /// Pooling window height.
    pub kernel_h: usize,
    /// Pooling window width.
    pub kernel_w: usize,
    /// Vertical window stride.
    pub stride_h: usize,
    /// Horizontal window stride.
    pub stride_w: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
90
/// User-supplied window and stride parameters for a 2-D pooling op.
#[derive(Debug, Clone, Copy)]
pub struct Pool2dSpec {
    /// Pooling window height.
    pub kernel_h: usize,
    /// Pooling window width.
    pub kernel_w: usize,
    /// Vertical window stride.
    pub stride_h: usize,
    /// Horizontal window stride.
    pub stride_w: usize,
}
98
/// Precomputed geometry for a standard 2-D convolution kernel.
/// NOTE(review): tensor layout and padding convention are not visible from
/// this file — confirm in the executing kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Conv2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input spatial height.
    pub in_h: usize,
    /// Input spatial width.
    pub in_w: usize,
    /// Input channel count.
    pub in_channels: usize,
    /// Output spatial height.
    pub out_h: usize,
    /// Output spatial width.
    pub out_w: usize,
    /// Output channel count (number of filters).
    pub out_channels: usize,
    /// Filter height.
    pub kernel_h: usize,
    /// Filter width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
114
/// User-supplied stride parameters for a standard 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct Conv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
120
/// Precomputed geometry for a depthwise 2-D convolution kernel.
/// NOTE(review): tensor layout is not visible from this file — confirm in
/// the executing kernel. Presumably `out_channels == channels * depth_multiplier`;
/// verify where the plan is built.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DepthwiseConv2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Input spatial height.
    pub in_h: usize,
    /// Input spatial width.
    pub in_w: usize,
    /// Input channel count.
    pub channels: usize,
    /// Filters applied per input channel.
    pub depth_multiplier: usize,
    /// Output spatial height.
    pub out_h: usize,
    /// Output spatial width.
    pub out_w: usize,
    /// Output channel count.
    pub out_channels: usize,
    /// Filter height.
    pub kernel_h: usize,
    /// Filter width.
    pub kernel_w: usize,
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
137
/// User-supplied stride parameters for a depthwise 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct DepthwiseConv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
143
/// User-supplied stride parameters for a separable 2-D convolution.
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dSpec {
    /// Vertical stride.
    pub stride_h: usize,
    /// Horizontal stride.
    pub stride_w: usize,
}
149
/// Borrowed weights for a separable convolution — presumably a depthwise
/// stage followed by a pointwise stage (field names suggest this; confirm
/// in the executing kernel).
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2dKernels<'a> {
    /// Weights for the depthwise stage.
    pub depthwise_kernel: &'a Tensor,
    /// Optional bias added after the depthwise stage.
    pub depthwise_bias: Option<&'a Tensor>,
    /// Weights for the pointwise stage.
    pub pointwise_kernel: &'a Tensor,
    /// Optional bias added after the pointwise stage.
    pub pointwise_bias: Option<&'a Tensor>,
}
157
/// Borrowed parameters for 2-D batch normalization.
/// NOTE(review): field names suggest the usual
/// `gamma * (x - mean) / sqrt(variance + epsilon) + beta` formula — confirm
/// in the executing kernel.
#[derive(Debug, Clone, Copy)]
pub struct BatchNorm2dTensors<'a> {
    /// Per-channel scale.
    pub gamma: &'a Tensor,
    /// Per-channel shift.
    pub beta: &'a Tensor,
    /// Precomputed (running) mean.
    pub mean: &'a Tensor,
    /// Precomputed (running) variance.
    pub variance: &'a Tensor,
    /// Small constant added for numerical stability.
    pub epsilon: f32,
}
166
/// Borrowed parameters for layer normalization over the last dimension.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormLastDimTensors<'a> {
    /// Elementwise scale.
    pub gamma: &'a Tensor,
    /// Elementwise shift.
    pub beta: &'a Tensor,
    /// Small constant added for numerical stability.
    pub epsilon: f32,
}
173
/// Precomputed geometry for a 2-D batch-normalization kernel.
/// NOTE(review): tensor layout is not visible from this file — confirm in
/// the executing kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BatchNorm2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Spatial height.
    pub height: usize,
    /// Spatial width.
    pub width: usize,
    /// Channel count.
    pub channels: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
182
/// Precomputed geometry for a row-wise softmax kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SoftmaxPlan {
    /// Length of each row the softmax is computed over.
    pub row_len: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
188
/// Precomputed geometry for a row-wise log-sum-exp reduction.
#[derive(Debug, Clone)]
pub(crate) struct LogSumExpPlan {
    /// Length of each row being reduced.
    pub row_len: usize,
    /// Shape of the reduced output tensor (owned, so the plan is `Clone`
    /// but not `Copy`).
    pub output_shape: Vec<usize>,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
195
/// Precomputed geometry for a row-wise layer-normalization kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) struct LayerNormPlan {
    /// Length of each normalized row.
    pub row_len: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
201
/// Borrowed parameters for 2-D group normalization.
#[derive(Debug, Clone, Copy)]
pub struct GroupNorm2dTensors<'a> {
    /// Per-channel scale.
    pub gamma: &'a Tensor,
    /// Per-channel shift.
    pub beta: &'a Tensor,
    /// Number of channel groups to normalize over.
    pub num_groups: usize,
    /// Small constant added for numerical stability.
    pub epsilon: f32,
}
209
/// Precomputed geometry for a 2-D group-normalization kernel.
/// NOTE(review): presumably `channels == num_groups * channels_per_group`;
/// verify where the plan is built.
#[derive(Debug, Clone, Copy)]
pub(crate) struct GroupNorm2dPlan {
    /// Number of images in the batch.
    pub batch: usize,
    /// Spatial height.
    pub height: usize,
    /// Spatial width.
    pub width: usize,
    /// Channel count.
    pub channels: usize,
    /// Number of channel groups.
    pub num_groups: usize,
    /// Channels in each group.
    pub channels_per_group: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
220
/// Borrowed parameters for RMS normalization over the last dimension.
/// Unlike layer norm, RMS norm carries no shift (`beta`) term.
#[derive(Debug, Clone, Copy)]
pub struct RmsNormLastDimTensors<'a> {
    /// Elementwise scale.
    pub gamma: &'a Tensor,
    /// Small constant added for numerical stability.
    pub epsilon: f32,
}
226
/// Precomputed geometry for a row-wise RMS-normalization kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) struct RmsNormPlan {
    /// Length of each normalized row.
    pub row_len: usize,
    /// Total number of elements in the output tensor.
    pub output_len: usize,
}
232
/// Elementwise binary operations dispatched by the backend.
#[derive(Debug, Clone, Copy)]
pub enum BinaryKind {
    /// Elementwise addition.
    Add,
    /// Elementwise subtraction.
    Sub,
    /// Elementwise multiplication.
    Mul,
}
239
/// Pooling reduction variants dispatched by the 2-D pooling kernel.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Pool2dKind {
    /// Take the maximum value in each window.
    Max,
    /// Take the arithmetic mean of each window.
    Avg,
}
245
246pub(crate) fn should_parallelize_len(
247 len: usize,
248 min_parallel_len: usize,
249 thread_pool: Option<&ThreadPool>,
250) -> bool {
251 if cfg!(miri) {
252 return false;
253 }
254 if len < min_parallel_len {
255 return false;
256 }
257 available_threads(thread_pool) > 1
258}
259
260pub(crate) fn available_threads(thread_pool: Option<&ThreadPool>) -> usize {
261 thread_pool
262 .map(ThreadPool::current_num_threads)
263 .unwrap_or_else(rayon::current_num_threads)
264}