Skip to main content

cubecl_core/frontend/operation/
binary.rs

1use crate as cubecl;
2use crate::ir::{Arithmetic, Bitwise, ManagedVariable, Operator, Scope};
3use crate::{
4    flex32,
5    frontend::{CubePrimitive, NativeExpand},
6    prelude::*,
7};
8use crate::{frontend::CubeType, tf32};
9use crate::{
10    frontend::operation::base::{binary_expand, binary_expand_fixed_output},
11    unexpanded,
12};
13use core::{cmp::Ordering, ops::*};
14use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
15use cubecl_ir::ClampOperator;
16use cubecl_macros::derive_expand;
17use half::{bf16, f16};
18
19pub mod add {
20    use super::*;
21
22    pub fn expand<C: CubePrimitive>(
23        scope: &mut Scope,
24        lhs: NativeExpand<C>,
25        rhs: NativeExpand<C>,
26    ) -> NativeExpand<C> {
27        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
28    }
29}
30
31pub mod sub {
32    use cubecl_ir::{ConstantValue, Variable};
33
34    use super::*;
35
36    pub fn expand<C: CubePrimitive>(
37        scope: &mut Scope,
38        lhs: NativeExpand<C>,
39        rhs: NativeExpand<C>,
40    ) -> NativeExpand<C> {
41        // Dirty hack to enable slice destructuring with trailing patterns on `Sequence`
42        match (lhs.expand.as_const(), rhs.expand.as_const()) {
43            (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
44                let item_lhs = lhs.expand.ty;
45                let item_rhs = rhs.expand.ty;
46
47                let vector_size = find_vectorization(item_lhs, item_rhs);
48
49                let item = item_lhs.with_vector_size(vector_size);
50                let value = (lhs_val - rhs_val).into();
51                ManagedVariable::Plain(Variable::constant(value, item)).into()
52            }
53            _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
54        }
55    }
56}
57
58pub mod mul {
59    use super::*;
60
61    pub fn expand<C: CubePrimitive>(
62        scope: &mut Scope,
63        lhs: NativeExpand<C>,
64        rhs: NativeExpand<C>,
65    ) -> NativeExpand<C> {
66        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
67    }
68}
69
70pub mod div {
71    use super::*;
72
73    pub fn expand<C: CubePrimitive>(
74        scope: &mut Scope,
75        lhs: NativeExpand<C>,
76        rhs: NativeExpand<C>,
77    ) -> NativeExpand<C> {
78        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
79    }
80}
81
82pub mod rem {
83    use super::*;
84
85    pub fn expand<C: CubePrimitive>(
86        scope: &mut Scope,
87        lhs: NativeExpand<C>,
88        rhs: NativeExpand<C>,
89    ) -> NativeExpand<C> {
90        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
91    }
92}
93
94pub mod and {
95    use super::*;
96
97    pub fn expand<C: CubePrimitive>(
98        scope: &mut Scope,
99        lhs: NativeExpand<C>,
100        rhs: NativeExpand<C>,
101    ) -> NativeExpand<bool> {
102        binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
103    }
104}
105
106pub mod bitand {
107    use super::*;
108
109    pub fn expand<C: CubePrimitive>(
110        scope: &mut Scope,
111        lhs: NativeExpand<C>,
112        rhs: NativeExpand<C>,
113    ) -> NativeExpand<C> {
114        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
115    }
116}
117
118pub mod bitor {
119    use super::*;
120
121    pub fn expand<C: CubePrimitive>(
122        scope: &mut Scope,
123        lhs: NativeExpand<C>,
124        rhs: NativeExpand<C>,
125    ) -> NativeExpand<C> {
126        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
127    }
128}
129
130pub mod or {
131    use super::*;
132
133    pub fn expand<C: CubePrimitive>(
134        scope: &mut Scope,
135        lhs: NativeExpand<C>,
136        rhs: NativeExpand<C>,
137    ) -> NativeExpand<bool> {
138        binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
139    }
140}
141
142pub mod bitxor {
143    use super::*;
144
145    pub fn expand<C: CubePrimitive>(
146        scope: &mut Scope,
147        lhs: NativeExpand<C>,
148        rhs: NativeExpand<C>,
149    ) -> NativeExpand<C> {
150        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
151    }
152}
153
154pub mod shl {
155    use super::*;
156
157    pub fn expand<C: CubePrimitive>(
158        scope: &mut Scope,
159        lhs: NativeExpand<C>,
160        rhs: NativeExpand<C>,
161    ) -> NativeExpand<C> {
162        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
163    }
164}
165
166pub mod shr {
167    use super::*;
168
169    pub fn expand<C: CubePrimitive>(
170        scope: &mut Scope,
171        lhs: NativeExpand<C>,
172        rhs: NativeExpand<C>,
173    ) -> NativeExpand<C> {
174        binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
175    }
176}
177
178pub mod clamp {
179    use super::*;
180
181    pub fn expand<C: PartialOrd + CubePrimitive>(
182        scope: &mut Scope,
183        input: NativeExpand<C>,
184        min: NativeExpand<C>,
185        max: NativeExpand<C>,
186    ) -> NativeExpand<C> {
187        unary_expand(scope, input.into(), |op| {
188            Arithmetic::Clamp(ClampOperator {
189                input: op.input,
190                min_value: *min.expand,
191                max_value: *max.expand,
192            })
193        })
194        .into()
195    }
196}
197
198pub mod clamp_max {
199    use super::*;
200
201    pub fn expand<C: PartialOrd + CubePrimitive>(
202        scope: &mut Scope,
203        lhs: NativeExpand<C>,
204        rhs: NativeExpand<C>,
205    ) -> NativeExpand<C> {
206        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
207    }
208}
209
210pub mod clamp_min {
211    use super::*;
212
213    pub fn expand<C: PartialOrd + CubePrimitive>(
214        scope: &mut Scope,
215        lhs: NativeExpand<C>,
216        rhs: NativeExpand<C>,
217    ) -> NativeExpand<C> {
218        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
219    }
220}
221
222/// The minimum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
223/// `clamp_max` may sometimes be more clear.
224pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
225    clamp_max(lhs, rhs)
226}
227
228pub mod min {
229    use super::*;
230
231    pub fn expand<C: PartialOrd + CubePrimitive>(
232        scope: &mut Scope,
233        lhs: NativeExpand<C>,
234        rhs: NativeExpand<C>,
235    ) -> NativeExpand<C> {
236        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
237    }
238}
239
240/// The maximum of two values, not requiring `Ord`. Provided for clarity in certain cases, though
241/// `clamp_min` may sometimes be more clear.
242pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
243    clamp_min(lhs, rhs)
244}
245
246pub mod max {
247    use super::*;
248
249    pub fn expand<C: PartialOrd + CubePrimitive>(
250        scope: &mut Scope,
251        lhs: NativeExpand<C>,
252        rhs: NativeExpand<C>,
253    ) -> NativeExpand<C> {
254        binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
255    }
256}
257
/// For binary functions without special syntax.
///
/// Arguments, in order: the trait name, the method name, the operator expression handed to
/// `binary_expand`, and the types that receive an empty impl of the trait.
///
/// Generates the trait itself, a companion `<Trait>Expand` trait carrying the
/// `__expand_<method>_method` function, and an impl of the companion for every
/// `NativeExpand<T>` whose `T` implements the trait.
macro_rules! impl_binary_func {
    ($name:ident, $func:ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait $name:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]> + Sized
            {
                fn $func(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Associated-function form; defers to the method on the expand type.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let left = self.into();
                    let right = rhs.into();
                    binary_expand(scope, left, right, $op).into()
                }
            }

            $(impl $name for $ty {})*
        }
    }
}
289
/// Like `impl_binary_func`, but the generated method returns `Self::Scalar` instead of
/// `Self`, and the expansion goes through `binary_expand_fixed_output` with an explicit
/// output type.
///
/// Arguments, in order: the trait name, the method name, the operator expression, and the
/// types that receive an empty impl of the trait.
macro_rules! impl_binary_func_scalar_out {
    ($name:ident, $func:ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar;
            }

            pub trait $name: CubePrimitive
                + CubeType<ExpandType: [<$name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized
            {
                fn $func(self, _rhs: Self) -> Self::Scalar {
                    unexpanded!()
                }

                // Associated-function form; defers to the method on the expand type.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: CubePrimitive + $name> [<$name Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar {
                    let left: ManagedVariable = self.into();
                    // Output type: the lhs type with its vector size set to 0.
                    // NOTE(review): presumably 0 denotes "not vectorized" — confirm.
                    let out_ty = left.ty.with_vector_size(0);
                    let right = rhs.into();
                    binary_expand_fixed_output(scope, left, right, out_ty, $op).into()
                }
            }

            $(impl $name for $ty {})*
        }
    }
}
325
/// Variant of `impl_binary_func` for functions whose right-hand operand has a different
/// type than the left-hand one.
///
/// Arguments, in order: the trait name, the method name, the right-hand type used for the
/// concrete impls, the operator expression handed to `binary_expand`, and the left-hand
/// types that receive an empty impl of `Trait<rhs type>`.
macro_rules! impl_binary_func_mixed_types {
    ($name:ident, $func:ident, $rhs:ident, $op:expr, $($ty:ty),*) => {
        paste::paste! {
            pub trait [<$name Expand>]<Rhs: CubeType> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            pub trait $name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ManagedVariable>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$name Expand>]<Rhs>> + Sized
            {
                fn $func(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Unlike the same-type macros, this calls `binary_expand` directly rather
                // than going through the `<Trait>Expand` method.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Rhs>,
                ) -> NativeExpand<Self> {
                    let left = lhs.into();
                    let right = rhs.into();
                    binary_expand(scope, left, right, $op).into()
                }
            }

            impl<Rhs: CubePrimitive, T: CubePrimitive + $name<Rhs>> [<$name Expand>]<Rhs> for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: NativeExpand<Rhs>) -> Self {
                    let left = self.into();
                    let right = rhs.into();
                    binary_expand(scope, left, right, $op).into()
                }
            }

            $(impl $name<$rhs> for $ty {})*
        }
    }
}
357
/// Generates the cube counterpart of a `core::ops` binary operator trait.
///
/// Arguments, in order: the operator trait (e.g. `Add`), the method name pasted into the
/// generated `__expand_<method>` / `__expand_<method>_method` functions, and the operator
/// expression handed to `binary_expand`.
///
/// Produces `Cube<Trait>` (blanket-implemented for every `T: Trait<Output = T> +
/// CubePrimitive`) and `<Trait>Expand` (implemented for the matching `NativeExpand<T>`).
macro_rules! impl_core_binop {
    ($ops_trait:ident, $func:ident, $operator:expr) => {
        paste::paste! {
            pub trait [<$ops_trait Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            pub trait [<Cube $ops_trait>]:
                $ops_trait<Output = Self>
                + CubePrimitive
                + CubeType<ExpandType: [<$ops_trait Expand>]>
                + Sized
            {
                // Associated-function form; defers to the method on the expand type.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: $ops_trait<Output = T> + CubePrimitive> [<$ops_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    let left = self.into();
                    let right = rhs.into();
                    binary_expand(scope, left, right, $operator).into()
                }
            }

            impl<T: $ops_trait<Output = T> + CubePrimitive> [<Cube $ops_trait>] for T {}
        }
    };
}
384
/// Generates the cube counterpart of a `core::ops` compound-assignment trait.
///
/// Arguments, in order: the assignment trait (e.g. `AddAssign`), the method name pasted into
/// the generated `__expand_<method>` / `__expand_<method>_method` functions, and the
/// operator expression handed to `assign_op_expand`.
///
/// The generated functions return nothing; the expansion is delegated to
/// `assign_op_expand`.
macro_rules! impl_core_assign_binop {
    ($ops_trait:ident, $func:ident, $operator:expr) => {
        paste::paste! {
            pub trait [<$ops_trait Expand>] {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self);
            }

            pub trait [<Cube $ops_trait>]:
                $ops_trait + CubePrimitive + CubeType<ExpandType: [<$ops_trait Expand>]> + Sized
            {
                // Associated-function form; defers to the method on the expand type.
                fn [<__expand_ $func>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) {
                    lhs.[<__expand_ $func _method>](scope, rhs)
                }
            }

            impl<T: $ops_trait + CubePrimitive> [<$ops_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $func _method>](self, scope: &mut Scope, rhs: Self) {
                    let target = self.into();
                    let operand = rhs.into();
                    assign_op_expand(scope, target, operand, $operator);
                }
            }

            impl<T: $ops_trait + CubePrimitive> [<Cube $ops_trait>] for T {}
        }
    };
}
411
412impl_core_binop!(Add, add, Arithmetic::Add);
413impl_core_binop!(Sub, sub, Arithmetic::Sub);
414impl_core_binop!(Mul, mul, Arithmetic::Mul);
415impl_core_binop!(Div, mul, Arithmetic::Div);
416impl_core_binop!(Rem, rem, Arithmetic::Modulo);
417
418impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
419impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
420impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
421impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
422impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
423
424#[derive_expand(CubeType, CubeTypeMut, IntoRuntime)]
425#[cube(runtime_variants, no_constructors)]
426pub enum Ordering {
427    Less = -1,
428    Equal = 0,
429    Greater = 1,
430}
431
432fn ordering_disc(name: &'static str) -> NativeExpand<i32> {
433    OrderingExpand::discriminant_of(name).into()
434}
435
436#[allow(non_snake_case)]
437pub trait CubeOrdering {
438    fn Less() -> Ordering {
439        Ordering::Less
440    }
441    fn Equal() -> Ordering {
442        Ordering::Equal
443    }
444    fn Greater() -> Ordering {
445        Ordering::Greater
446    }
447    fn __expand_Less(_scope: &mut Scope) -> OrderingExpand {
448        OrderingExpand {
449            discriminant: ordering_disc("Less"),
450            value: (),
451        }
452    }
453    fn __expand_Equal(_scope: &mut Scope) -> OrderingExpand {
454        OrderingExpand {
455            discriminant: ordering_disc("Equal"),
456            value: (),
457        }
458    }
459    fn __expand_Greater(_scope: &mut Scope) -> OrderingExpand {
460        OrderingExpand {
461            discriminant: ordering_disc("Greater"),
462            value: (),
463        }
464    }
465}
466
467impl CubeOrdering for Ordering {}
468
469pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
470    fn __expand_cmp(
471        scope: &mut Scope,
472        lhs: Self::ExpandType,
473        rhs: Self::ExpandType,
474    ) -> OrderingExpand {
475        lhs.__expand_cmp_method(scope, rhs)
476    }
477
478    fn __expand_min(
479        scope: &mut Scope,
480        lhs: Self::ExpandType,
481        rhs: Self::ExpandType,
482    ) -> Self::ExpandType {
483        lhs.__expand_min_method(scope, rhs)
484    }
485
486    fn __expand_max(
487        scope: &mut Scope,
488        lhs: Self::ExpandType,
489        rhs: Self::ExpandType,
490    ) -> Self::ExpandType {
491        lhs.__expand_max_method(scope, rhs)
492    }
493
494    fn __expand_clamp(
495        scope: &mut Scope,
496        lhs: Self::ExpandType,
497        min: Self::ExpandType,
498        max: Self::ExpandType,
499    ) -> Self::ExpandType {
500        lhs.__expand_clamp_method(scope, min, max)
501    }
502}
503pub trait OrdExpand {
504    fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand;
505    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
506    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
507    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
508}
509
510impl<T: Ord + CubePrimitive> CubeOrd for T {}
511impl<T: Ord + CubePrimitive> OrdExpand for NativeExpand<T> {
512    fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand {
513        let lhs_lt_rhs = lt::expand(scope, self.clone(), rhs.clone());
514        let lhs_gt_rhs = gt::expand(scope, self, rhs);
515        let less = ordering_disc("Less");
516        let equal = ordering_disc("Equal");
517        let greater = ordering_disc("Greater");
518        let eq_or_gt = select::expand(scope, lhs_gt_rhs, greater, equal);
519        let discriminant = select::expand(scope, lhs_lt_rhs, less, eq_or_gt);
520        OrderingExpand {
521            discriminant,
522            value: (),
523        }
524    }
525    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
526        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
527    }
528    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
529        binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
530    }
531    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
532        unary_expand(scope, self.into(), |op| {
533            Arithmetic::Clamp(ClampOperator {
534                input: op.input,
535                min_value: *min.expand,
536                max_value: *max.expand,
537            })
538        })
539        .into()
540    }
541}
542
543impl_binary_func!(
544    Powf,
545    powf,
546    Arithmetic::Powf,
547    f16,
548    bf16,
549    flex32,
550    tf32,
551    f32,
552    f64,
553    num_complex::Complex<f32>,
554    num_complex::Complex<f64>
555);
556
557impl_binary_func!(
558    Hypot,
559    hypot,
560    Arithmetic::Hypot,
561    f16,
562    bf16,
563    flex32,
564    tf32,
565    f32,
566    f64
567);
568
569impl_binary_func!(
570    Rhypot,
571    rhypot,
572    Arithmetic::Rhypot,
573    f16,
574    bf16,
575    flex32,
576    tf32,
577    f32,
578    f64
579);
580
581impl_binary_func!(
582    ArcTan2,
583    atan2,
584    Arithmetic::ArcTan2,
585    f16,
586    bf16,
587    flex32,
588    tf32,
589    f32,
590    f64
591);
592impl_binary_func!(
593    Remainder,
594    rem,
595    Arithmetic::Remainder,
596    e2m1,
597    e4m3,
598    e5m2,
599    ue8m0,
600    f16,
601    bf16,
602    flex32,
603    tf32,
604    f32,
605    f64,
606    i8,
607    i16,
608    i32,
609    i64,
610    u8,
611    u16,
612    u32,
613    u64,
614    usize,
615    isize
616);
617impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
618impl_binary_func!(
619    SaturatingAdd,
620    saturating_add,
621    Arithmetic::SaturatingAdd,
622    i8,
623    i16,
624    i32,
625    i64,
626    u8,
627    u16,
628    u32,
629    u64,
630    usize,
631    isize
632);
633impl_binary_func!(
634    SaturatingSub,
635    saturating_sub,
636    Arithmetic::SaturatingSub,
637    i8,
638    i16,
639    i32,
640    i64,
641    u8,
642    u16,
643    u32,
644    u64,
645    usize,
646    isize
647);
648impl_binary_func_scalar_out!(
649    Dot,
650    dot,
651    Arithmetic::Dot,
652    f16,
653    bf16,
654    flex32,
655    tf32,
656    f32,
657    f64,
658    i8,
659    i16,
660    i32,
661    i64,
662    u8,
663    u16,
664    u32,
665    u64,
666    usize,
667    isize
668);
669
670impl_binary_func_mixed_types!(
671    Powi,
672    powi,
673    i32,
674    Arithmetic::Powi,
675    f16,
676    bf16,
677    flex32,
678    tf32,
679    f32,
680    f64,
681    i8,
682    i16,
683    i32,
684    i64,
685    u8,
686    u16,
687    u32,
688    u64,
689    usize,
690    isize
691);