1use cubecl_ir::{
2 Arithmetic, BinaryOperator, Comparison, ElemType, IndexAssignOperator, IndexOperator,
3 Instruction, ManagedVariable, Operation, Operator, Scope, Type, UnaryOperator, Variable,
4 VariableKind, VectorSize,
5};
6use cubecl_macros::cube;
7
8use crate::{
9 self as cubecl,
10 frontend::{validate_complex_assign_operation, validate_complex_operation},
11 prelude::{CubeIndex, CubeType, Int, NativeExpand, eq, rem},
12};
13
14pub(crate) fn binary_expand<F, Op>(
15 scope: &mut Scope,
16 lhs: ManagedVariable,
17 rhs: ManagedVariable,
18 func: F,
19) -> ManagedVariable
20where
21 F: Fn(BinaryOperator) -> Op,
22 Op: Into<Operation>,
23{
24 let lhs = lhs.consume();
25 let rhs = rhs.consume();
26
27 let item_lhs = lhs.ty;
28 let item_rhs = rhs.ty;
29
30 let vector_size = find_vectorization(item_lhs, item_rhs);
31
32 let item = item_lhs.with_vector_size(vector_size);
33
34 let output = scope.create_local(item);
35 let out = *output;
36
37 let op = func(BinaryOperator { lhs, rhs }).into();
38 validate_complex_operation(scope, &op);
39
40 scope.register(Instruction::new(op, out));
41
42 output
43}
44
45pub(crate) fn index_expand_no_vec<F>(
46 scope: &mut Scope,
47 list: ManagedVariable,
48 index: ManagedVariable,
49 func: F,
50) -> ManagedVariable
51where
52 F: Fn(IndexOperator) -> Operator,
53{
54 let list = list.consume();
55 let index = index.consume();
56
57 let item_lhs = list.ty;
58
59 let item = item_lhs.with_vector_size(0);
60
61 let output = scope.create_local(item);
62 let out = *output;
63
64 let op = func(IndexOperator {
65 list,
66 index,
67 vector_size: 0,
68 unroll_factor: 1,
69 });
70
71 scope.register(Instruction::new(op, out));
72
73 output
74}
75pub(crate) fn index_expand<F, Op>(
76 scope: &mut Scope,
77 list: ManagedVariable,
78 index: ManagedVariable,
79 vector_size: Option<VectorSize>,
80 func: F,
81) -> ManagedVariable
82where
83 F: Fn(IndexOperator) -> Op,
84 Op: Into<Operation>,
85{
86 let list = list.consume();
87 let index = index.consume();
88
89 let item_lhs = list.ty;
90 let item_rhs = index.ty;
91
92 let vec = if let Some(vector_size) = vector_size {
93 vector_size
94 } else {
95 find_vectorization(item_lhs, item_rhs)
96 };
97
98 let item = item_lhs.with_vector_size(vec);
99
100 let output = scope.create_local(item);
101 let out = *output;
102
103 let op = func(IndexOperator {
104 list,
105 index,
106 vector_size: vector_size.unwrap_or(0),
107 unroll_factor: 1,
108 });
109
110 scope.register(Instruction::new(op, out));
111
112 output
113}
114
115pub(crate) fn binary_expand_fixed_output<F>(
116 scope: &mut Scope,
117 lhs: ManagedVariable,
118 rhs: ManagedVariable,
119 out_item: Type,
120 func: F,
121) -> ManagedVariable
122where
123 F: Fn(BinaryOperator) -> Arithmetic,
124{
125 let lhs_var = lhs.consume();
126 let rhs_var = rhs.consume();
127
128 let out = scope.create_local(out_item);
129
130 let out_var = *out;
131
132 let op = func(BinaryOperator {
133 lhs: lhs_var,
134 rhs: rhs_var,
135 })
136 .into();
137 validate_complex_operation(scope, &op);
138
139 scope.register(Instruction::new(op, out_var));
140
141 out
142}
143
144pub(crate) fn cmp_expand<F>(
145 scope: &mut Scope,
146 lhs: ManagedVariable,
147 rhs: ManagedVariable,
148 func: F,
149) -> ManagedVariable
150where
151 F: Fn(BinaryOperator) -> Comparison,
152{
153 let lhs = lhs.consume();
154 let rhs = rhs.consume();
155
156 let item_lhs = lhs.ty;
157 let item_rhs = rhs.ty;
158
159 let vector_size = find_vectorization(item_lhs, item_rhs);
160
161 let out_item = Type::scalar(ElemType::Bool).with_vector_size(vector_size);
162
163 let out = scope.create_local(out_item);
164 let out_var = *out;
165
166 let op = func(BinaryOperator { lhs, rhs }).into();
167 validate_complex_operation(scope, &op);
168
169 scope.register(Instruction::new(op, out_var));
170
171 out
172}
173
174pub(crate) fn assign_op_expand<F, Op>(
175 scope: &mut Scope,
176 lhs: ManagedVariable,
177 rhs: ManagedVariable,
178 func: F,
179) -> ManagedVariable
180where
181 F: Fn(BinaryOperator) -> Op,
182 Op: Into<Operation>,
183{
184 if lhs.is_immutable() {
185 panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
186 }
187 let lhs_var: Variable = *lhs;
188 let rhs: Variable = *rhs;
189
190 let op = func(BinaryOperator { lhs: lhs_var, rhs }).into();
191 validate_complex_assign_operation(scope, &op);
192
193 scope.register(Instruction::new(op, lhs_var));
194
195 lhs
196}
197
198pub fn unary_expand<F, Op>(scope: &mut Scope, input: ManagedVariable, func: F) -> ManagedVariable
199where
200 F: Fn(UnaryOperator) -> Op,
201 Op: Into<Operation>,
202{
203 let input = input.consume();
204 let item = input.ty;
205
206 let out = scope.create_local(item);
207 let out_var = *out;
208
209 let op = func(UnaryOperator { input }).into();
210 validate_complex_operation(scope, &op);
211
212 scope.register(Instruction::new(op, out_var));
213
214 out
215}
216
217pub fn unary_expand_fixed_output<F, Op>(
218 scope: &mut Scope,
219 input: ManagedVariable,
220 out_item: Type,
221 func: F,
222) -> ManagedVariable
223where
224 F: Fn(UnaryOperator) -> Op,
225 Op: Into<Operation>,
226{
227 let input = input.consume();
228 let output = scope.create_local(out_item);
229 let out = *output;
230
231 let op = func(UnaryOperator { input }).into();
232 validate_complex_operation(scope, &op);
233
234 scope.register(Instruction::new(op, out));
235
236 output
237}
238
239pub fn init_expand<F>(
240 scope: &mut Scope,
241 input: ManagedVariable,
242 mutable: bool,
243 func: F,
244) -> ManagedVariable
245where
246 F: Fn(Variable) -> Operation,
247{
248 let input_var: Variable = *input;
249 let item = input.ty;
250
251 let out = if mutable {
252 scope.create_local_mut(item)
253 } else {
254 scope.create_local(item)
255 };
256
257 let out_var = *out;
258
259 let op = func(input_var);
260 scope.register(Instruction::new(op, out_var));
261
262 out
263}
264
265pub(crate) fn find_vectorization(lhs: Type, rhs: Type) -> VectorSize {
266 if matches!(lhs, Type::Scalar(_)) && matches!(rhs, Type::Scalar(_)) {
267 0
268 } else {
269 lhs.vector_size().max(rhs.vector_size())
270 }
271}
272
273pub fn array_assign_binary_op_expand<
274 A: CubeType + CubeIndex,
275 V: CubeType,
276 F: Fn(BinaryOperator) -> Op,
277 Op: Into<Operation>,
278>(
279 scope: &mut Scope,
280 array: NativeExpand<A>,
281 index: NativeExpand<usize>,
282 value: NativeExpand<V>,
283 func: F,
284) where
285 A::Output: CubeType + Sized,
286{
287 let array: ManagedVariable = array.into();
288 let index: ManagedVariable = index.into();
289 let value: ManagedVariable = value.into();
290
291 let array_item = match array.kind {
292 VariableKind::LocalMut { .. } => array.ty.with_vector_size(0),
294 _ => array.ty,
295 };
296 let array_value = scope.create_local(array_item);
297
298 let read = Instruction::new(
299 Operator::Index(IndexOperator {
300 list: *array,
301 index: *index,
302 vector_size: 0,
303 unroll_factor: 1,
304 }),
305 *array_value,
306 );
307 let array_value = array_value.consume();
308 let op_out = scope.create_local(array_item);
309 let calculate = Instruction::new(
310 func(BinaryOperator {
311 lhs: array_value,
312 rhs: *value,
313 }),
314 *op_out,
315 );
316
317 let write = Operator::IndexAssign(IndexAssignOperator {
318 index: *index,
319 value: op_out.consume(),
320 vector_size: 0,
321 unroll_factor: 1,
322 });
323 scope.register(read);
324 scope.register(calculate);
325 scope.register(Instruction::new(write, *array));
326}
327
328pub trait DivCeil: Int + CubeType<ExpandType: DivCeilExpand<Self>> {
329 fn div_ceil(self, divisor: Self) -> Self;
330
331 fn __expand_div_ceil(
332 scope: &mut Scope,
333 a: NativeExpand<Self>,
334 b: NativeExpand<Self>,
335 ) -> NativeExpand<Self> {
336 a.__expand_div_ceil_method(scope, b)
337 }
338}
339
340pub trait DivCeilExpand<E: Int> {
341 fn __expand_div_ceil_method(self, scope: &mut Scope, divisor: Self) -> Self;
342}
343
344impl<E: DivCeil> DivCeilExpand<E> for NativeExpand<E> {
345 fn __expand_div_ceil_method(
346 self,
347 scope: &mut Scope,
348 divisor: NativeExpand<E>,
349 ) -> NativeExpand<E> {
350 div_ceil::expand::<E>(scope, self, divisor)
351 }
352}
353
354macro_rules! impl_div_ceil {
355 ($($ty:ty),*) => {
356 $(
357 impl DivCeil for $ty {
358 #[allow(clippy::manual_div_ceil)] fn div_ceil(self, divisor: Self) -> Self {
360 (self + divisor - 1) / divisor
361 }
362 }
363 )*
364 };
365}
366
367impl_div_ceil!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
368
369impl<E: Int> NativeExpand<E> {
370 pub fn __expand_is_multiple_of_method(
371 self,
372 scope: &mut Scope,
373 factor: NativeExpand<E>,
374 ) -> NativeExpand<bool> {
375 let modulo = rem::expand(scope, self, factor);
376 eq::expand(scope, modulo, E::from_int(0).into())
377 }
378}
379
380#[cube]
381pub fn div_ceil<E: Int>(a: E, b: E) -> E {
382 (a + b - E::new(1)) / b
383}