Skip to main content

cubecl_core/frontend/operation/base.rs

1use cubecl_ir::{
2    Arithmetic, BinaryOperator, Comparison, ElemType, IndexAssignOperator, IndexOperator,
3    Instruction, ManagedVariable, Operation, Operator, Scope, Type, UnaryOperator, Variable,
4    VariableKind, VectorSize,
5};
6use cubecl_macros::cube;
7
8use crate::{
9    self as cubecl,
10    frontend::{validate_complex_assign_operation, validate_complex_operation},
11    prelude::{CubeIndex, CubeType, Int, NativeExpand, eq, rem},
12};
13
14pub(crate) fn binary_expand<F, Op>(
15    scope: &mut Scope,
16    lhs: ManagedVariable,
17    rhs: ManagedVariable,
18    func: F,
19) -> ManagedVariable
20where
21    F: Fn(BinaryOperator) -> Op,
22    Op: Into<Operation>,
23{
24    let lhs = lhs.consume();
25    let rhs = rhs.consume();
26
27    let item_lhs = lhs.ty;
28    let item_rhs = rhs.ty;
29
30    let vector_size = find_vectorization(item_lhs, item_rhs);
31
32    let item = item_lhs.with_vector_size(vector_size);
33
34    let output = scope.create_local(item);
35    let out = *output;
36
37    let op = func(BinaryOperator { lhs, rhs }).into();
38    validate_complex_operation(scope, &op);
39
40    scope.register(Instruction::new(op, out));
41
42    output
43}
44
45pub(crate) fn index_expand_no_vec<F>(
46    scope: &mut Scope,
47    list: ManagedVariable,
48    index: ManagedVariable,
49    func: F,
50) -> ManagedVariable
51where
52    F: Fn(IndexOperator) -> Operator,
53{
54    let list = list.consume();
55    let index = index.consume();
56
57    let item_lhs = list.ty;
58
59    let item = item_lhs.with_vector_size(0);
60
61    let output = scope.create_local(item);
62    let out = *output;
63
64    let op = func(IndexOperator {
65        list,
66        index,
67        vector_size: 0,
68        unroll_factor: 1,
69    });
70
71    scope.register(Instruction::new(op, out));
72
73    output
74}
75pub(crate) fn index_expand<F, Op>(
76    scope: &mut Scope,
77    list: ManagedVariable,
78    index: ManagedVariable,
79    vector_size: Option<VectorSize>,
80    func: F,
81) -> ManagedVariable
82where
83    F: Fn(IndexOperator) -> Op,
84    Op: Into<Operation>,
85{
86    let list = list.consume();
87    let index = index.consume();
88
89    let item_lhs = list.ty;
90    let item_rhs = index.ty;
91
92    let vec = if let Some(vector_size) = vector_size {
93        vector_size
94    } else {
95        find_vectorization(item_lhs, item_rhs)
96    };
97
98    let item = item_lhs.with_vector_size(vec);
99
100    let output = scope.create_local(item);
101    let out = *output;
102
103    let op = func(IndexOperator {
104        list,
105        index,
106        vector_size: vector_size.unwrap_or(0),
107        unroll_factor: 1,
108    });
109
110    scope.register(Instruction::new(op, out));
111
112    output
113}
114
115pub(crate) fn binary_expand_fixed_output<F>(
116    scope: &mut Scope,
117    lhs: ManagedVariable,
118    rhs: ManagedVariable,
119    out_item: Type,
120    func: F,
121) -> ManagedVariable
122where
123    F: Fn(BinaryOperator) -> Arithmetic,
124{
125    let lhs_var = lhs.consume();
126    let rhs_var = rhs.consume();
127
128    let out = scope.create_local(out_item);
129
130    let out_var = *out;
131
132    let op = func(BinaryOperator {
133        lhs: lhs_var,
134        rhs: rhs_var,
135    })
136    .into();
137    validate_complex_operation(scope, &op);
138
139    scope.register(Instruction::new(op, out_var));
140
141    out
142}
143
144pub(crate) fn cmp_expand<F>(
145    scope: &mut Scope,
146    lhs: ManagedVariable,
147    rhs: ManagedVariable,
148    func: F,
149) -> ManagedVariable
150where
151    F: Fn(BinaryOperator) -> Comparison,
152{
153    let lhs = lhs.consume();
154    let rhs = rhs.consume();
155
156    let item_lhs = lhs.ty;
157    let item_rhs = rhs.ty;
158
159    let vector_size = find_vectorization(item_lhs, item_rhs);
160
161    let out_item = Type::scalar(ElemType::Bool).with_vector_size(vector_size);
162
163    let out = scope.create_local(out_item);
164    let out_var = *out;
165
166    let op = func(BinaryOperator { lhs, rhs }).into();
167    validate_complex_operation(scope, &op);
168
169    scope.register(Instruction::new(op, out_var));
170
171    out
172}
173
174pub(crate) fn assign_op_expand<F, Op>(
175    scope: &mut Scope,
176    lhs: ManagedVariable,
177    rhs: ManagedVariable,
178    func: F,
179) -> ManagedVariable
180where
181    F: Fn(BinaryOperator) -> Op,
182    Op: Into<Operation>,
183{
184    if lhs.is_immutable() {
185        panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
186    }
187    let lhs_var: Variable = *lhs;
188    let rhs: Variable = *rhs;
189
190    let op = func(BinaryOperator { lhs: lhs_var, rhs }).into();
191    validate_complex_assign_operation(scope, &op);
192
193    scope.register(Instruction::new(op, lhs_var));
194
195    lhs
196}
197
198pub fn unary_expand<F, Op>(scope: &mut Scope, input: ManagedVariable, func: F) -> ManagedVariable
199where
200    F: Fn(UnaryOperator) -> Op,
201    Op: Into<Operation>,
202{
203    let input = input.consume();
204    let item = input.ty;
205
206    let out = scope.create_local(item);
207    let out_var = *out;
208
209    let op = func(UnaryOperator { input }).into();
210    validate_complex_operation(scope, &op);
211
212    scope.register(Instruction::new(op, out_var));
213
214    out
215}
216
217pub fn unary_expand_fixed_output<F, Op>(
218    scope: &mut Scope,
219    input: ManagedVariable,
220    out_item: Type,
221    func: F,
222) -> ManagedVariable
223where
224    F: Fn(UnaryOperator) -> Op,
225    Op: Into<Operation>,
226{
227    let input = input.consume();
228    let output = scope.create_local(out_item);
229    let out = *output;
230
231    let op = func(UnaryOperator { input }).into();
232    validate_complex_operation(scope, &op);
233
234    scope.register(Instruction::new(op, out));
235
236    output
237}
238
239pub fn init_expand<F>(
240    scope: &mut Scope,
241    input: ManagedVariable,
242    mutable: bool,
243    func: F,
244) -> ManagedVariable
245where
246    F: Fn(Variable) -> Operation,
247{
248    let input_var: Variable = *input;
249    let item = input.ty;
250
251    let out = if mutable {
252        scope.create_local_mut(item)
253    } else {
254        scope.create_local(item)
255    };
256
257    let out_var = *out;
258
259    let op = func(input_var);
260    scope.register(Instruction::new(op, out_var));
261
262    out
263}
264
265pub(crate) fn find_vectorization(lhs: Type, rhs: Type) -> VectorSize {
266    if matches!(lhs, Type::Scalar(_)) && matches!(rhs, Type::Scalar(_)) {
267        0
268    } else {
269        lhs.vector_size().max(rhs.vector_size())
270    }
271}
272
/// Expands a compound assignment on an indexed element, `array[index] = func(array[index], value)`.
///
/// Three instructions are built and then registered into `scope` in this order:
/// 1. an `Index` read of `array[index]` into a fresh local,
/// 2. `func(BinaryOperator { lhs: <that element>, rhs: value })` into a second fresh local,
/// 3. an `IndexAssign` that writes the result back at `index`, with `array` as its output.
///
/// Unlike `binary_expand` / `assign_op_expand`, no `validate_complex_*` check is run here.
/// `array`, `index` and `value` are dereferenced, not consumed.
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex,
    V: CubeType,
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
>(
    scope: &mut Scope,
    array: NativeExpand<A>,
    index: NativeExpand<usize>,
    value: NativeExpand<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    let array: ManagedVariable = array.into();
    let index: ManagedVariable = index.into();
    let value: ManagedVariable = value.into();

    // Type of one element read out of `array`.
    let array_item = match array.kind {
        // In that case, the array is a vector: an element has vector size 0.
        VariableKind::LocalMut { .. } => array.ty.with_vector_size(0),
        _ => array.ty,
    };
    let array_value = scope.create_local(array_item);

    // Read `array[index]` into `array_value` (built now, registered at the end).
    let read = Instruction::new(
        Operator::Index(IndexOperator {
            list: *array,
            index: *index,
            vector_size: 0,
            unroll_factor: 1,
        }),
        *array_value,
    );
    let array_value = array_value.consume();
    let op_out = scope.create_local(array_item);
    // Apply the operation to the read element and `value`.
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // Store the computed value back at the same index.
    let write = Operator::IndexAssign(IndexAssignOperator {
        index: *index,
        value: op_out.consume(),
        vector_size: 0,
        unroll_factor: 1,
    });
    scope.register(read);
    scope.register(calculate);
    scope.register(Instruction::new(write, *array));
}
327
328pub trait DivCeil: Int + CubeType<ExpandType: DivCeilExpand<Self>> {
329    fn div_ceil(self, divisor: Self) -> Self;
330
331    fn __expand_div_ceil(
332        scope: &mut Scope,
333        a: NativeExpand<Self>,
334        b: NativeExpand<Self>,
335    ) -> NativeExpand<Self> {
336        a.__expand_div_ceil_method(scope, b)
337    }
338}
339
/// Expansion-side half of [`DivCeil`], implemented on expanded value types.
///
/// `DivCeil` requires its `ExpandType` to implement this trait. This file provides the impl
/// for `NativeExpand<E>`.
pub trait DivCeilExpand<E: Int> {
    /// Builds `self.div_ceil(divisor)` in `scope` and returns the expanded result.
    fn __expand_div_ceil_method(self, scope: &mut Scope, divisor: Self) -> Self;
}
343
344impl<E: DivCeil> DivCeilExpand<E> for NativeExpand<E> {
345    fn __expand_div_ceil_method(
346        self,
347        scope: &mut Scope,
348        divisor: NativeExpand<E>,
349    ) -> NativeExpand<E> {
350        div_ceil::expand::<E>(scope, self, divisor)
351    }
352}
353
354macro_rules! impl_div_ceil {
355    ($($ty:ty),*) => {
356        $(
357            impl DivCeil for $ty {
358                #[allow(clippy::manual_div_ceil)] // Need to define div_ceil to use div_ceil!
359                fn div_ceil(self, divisor: Self) -> Self {
360                    (self + divisor - 1) / divisor
361                }
362            }
363        )*
364    };
365}
366
367impl_div_ceil!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
368
369impl<E: Int> NativeExpand<E> {
370    pub fn __expand_is_multiple_of_method(
371        self,
372        scope: &mut Scope,
373        factor: NativeExpand<E>,
374    ) -> NativeExpand<bool> {
375        let modulo = rem::expand(scope, self, factor);
376        eq::expand(scope, modulo, E::from_int(0).into())
377    }
378}
379
/// Ceiling division `a / b` rounded up, computed as `(a + b - 1) / b`.
///
/// NOTE(review): this formula presumes non-negative operands and that `a + b - 1` stays
/// within `E`'s range. With a negative signed operand, or `a` near `E`'s maximum, the
/// result is probably not the mathematical ceiling. That assumes truncating integer
/// division and non-trapping overflow on the target; confirm.
#[cube]
pub fn div_ceil<E: Int>(a: E, b: E) -> E {
    (a + b - E::new(1)) / b
}