1use crate as cubecl;
2use crate::ir::{Arithmetic, Bitwise, ManagedVariable, Operator, Scope};
3use crate::{
4 flex32,
5 frontend::{CubePrimitive, NativeExpand},
6 prelude::*,
7};
8use crate::{frontend::CubeType, tf32};
9use crate::{
10 frontend::operation::base::{binary_expand, binary_expand_fixed_output},
11 unexpanded,
12};
13use core::{cmp::Ordering, ops::*};
14use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
15use cubecl_ir::ClampOperator;
16use cubecl_macros::derive_expand;
17use half::{bf16, f16};
18
19pub mod add {
20 use super::*;
21
22 pub fn expand<C: CubePrimitive>(
23 scope: &mut Scope,
24 lhs: NativeExpand<C>,
25 rhs: NativeExpand<C>,
26 ) -> NativeExpand<C> {
27 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
28 }
29}
30
31pub mod sub {
32 use cubecl_ir::{ConstantValue, Variable};
33
34 use super::*;
35
36 pub fn expand<C: CubePrimitive>(
37 scope: &mut Scope,
38 lhs: NativeExpand<C>,
39 rhs: NativeExpand<C>,
40 ) -> NativeExpand<C> {
41 match (lhs.expand.as_const(), rhs.expand.as_const()) {
43 (Some(ConstantValue::UInt(lhs_val)), Some(ConstantValue::UInt(rhs_val))) => {
44 let item_lhs = lhs.expand.ty;
45 let item_rhs = rhs.expand.ty;
46
47 let vector_size = find_vectorization(item_lhs, item_rhs);
48
49 let item = item_lhs.with_vector_size(vector_size);
50 let value = (lhs_val - rhs_val).into();
51 ManagedVariable::Plain(Variable::constant(value, item)).into()
52 }
53 _ => binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub).into(),
54 }
55 }
56}
57
58pub mod mul {
59 use super::*;
60
61 pub fn expand<C: CubePrimitive>(
62 scope: &mut Scope,
63 lhs: NativeExpand<C>,
64 rhs: NativeExpand<C>,
65 ) -> NativeExpand<C> {
66 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul).into()
67 }
68}
69
70pub mod div {
71 use super::*;
72
73 pub fn expand<C: CubePrimitive>(
74 scope: &mut Scope,
75 lhs: NativeExpand<C>,
76 rhs: NativeExpand<C>,
77 ) -> NativeExpand<C> {
78 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div).into()
79 }
80}
81
82pub mod rem {
83 use super::*;
84
85 pub fn expand<C: CubePrimitive>(
86 scope: &mut Scope,
87 lhs: NativeExpand<C>,
88 rhs: NativeExpand<C>,
89 ) -> NativeExpand<C> {
90 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo).into()
91 }
92}
93
94pub mod and {
95 use super::*;
96
97 pub fn expand<C: CubePrimitive>(
98 scope: &mut Scope,
99 lhs: NativeExpand<C>,
100 rhs: NativeExpand<C>,
101 ) -> NativeExpand<bool> {
102 binary_expand(scope, lhs.into(), rhs.into(), Operator::And).into()
103 }
104}
105
106pub mod bitand {
107 use super::*;
108
109 pub fn expand<C: CubePrimitive>(
110 scope: &mut Scope,
111 lhs: NativeExpand<C>,
112 rhs: NativeExpand<C>,
113 ) -> NativeExpand<C> {
114 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd).into()
115 }
116}
117
118pub mod bitor {
119 use super::*;
120
121 pub fn expand<C: CubePrimitive>(
122 scope: &mut Scope,
123 lhs: NativeExpand<C>,
124 rhs: NativeExpand<C>,
125 ) -> NativeExpand<C> {
126 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr).into()
127 }
128}
129
130pub mod or {
131 use super::*;
132
133 pub fn expand<C: CubePrimitive>(
134 scope: &mut Scope,
135 lhs: NativeExpand<C>,
136 rhs: NativeExpand<C>,
137 ) -> NativeExpand<bool> {
138 binary_expand(scope, lhs.into(), rhs.into(), Operator::Or).into()
139 }
140}
141
142pub mod bitxor {
143 use super::*;
144
145 pub fn expand<C: CubePrimitive>(
146 scope: &mut Scope,
147 lhs: NativeExpand<C>,
148 rhs: NativeExpand<C>,
149 ) -> NativeExpand<C> {
150 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor).into()
151 }
152}
153
154pub mod shl {
155 use super::*;
156
157 pub fn expand<C: CubePrimitive>(
158 scope: &mut Scope,
159 lhs: NativeExpand<C>,
160 rhs: NativeExpand<C>,
161 ) -> NativeExpand<C> {
162 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft).into()
163 }
164}
165
166pub mod shr {
167 use super::*;
168
169 pub fn expand<C: CubePrimitive>(
170 scope: &mut Scope,
171 lhs: NativeExpand<C>,
172 rhs: NativeExpand<C>,
173 ) -> NativeExpand<C> {
174 binary_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight).into()
175 }
176}
177
178pub mod clamp {
179 use super::*;
180
181 pub fn expand<C: PartialOrd + CubePrimitive>(
182 scope: &mut Scope,
183 input: NativeExpand<C>,
184 min: NativeExpand<C>,
185 max: NativeExpand<C>,
186 ) -> NativeExpand<C> {
187 unary_expand(scope, input.into(), |op| {
188 Arithmetic::Clamp(ClampOperator {
189 input: op.input,
190 min_value: *min.expand,
191 max_value: *max.expand,
192 })
193 })
194 .into()
195 }
196}
197
198pub mod clamp_max {
199 use super::*;
200
201 pub fn expand<C: PartialOrd + CubePrimitive>(
202 scope: &mut Scope,
203 lhs: NativeExpand<C>,
204 rhs: NativeExpand<C>,
205 ) -> NativeExpand<C> {
206 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
207 }
208}
209
210pub mod clamp_min {
211 use super::*;
212
213 pub fn expand<C: PartialOrd + CubePrimitive>(
214 scope: &mut Scope,
215 lhs: NativeExpand<C>,
216 rhs: NativeExpand<C>,
217 ) -> NativeExpand<C> {
218 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
219 }
220}
221
222pub fn min<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
225 clamp_max(lhs, rhs)
226}
227
228pub mod min {
229 use super::*;
230
231 pub fn expand<C: PartialOrd + CubePrimitive>(
232 scope: &mut Scope,
233 lhs: NativeExpand<C>,
234 rhs: NativeExpand<C>,
235 ) -> NativeExpand<C> {
236 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Min).into()
237 }
238}
239
240pub fn max<T: PartialOrd + CubePrimitive>(lhs: T, rhs: T) -> T {
243 clamp_min(lhs, rhs)
244}
245
246pub mod max {
247 use super::*;
248
249 pub fn expand<C: PartialOrd + CubePrimitive>(
250 scope: &mut Scope,
251 lhs: NativeExpand<C>,
252 rhs: NativeExpand<C>,
253 ) -> NativeExpand<C> {
254 binary_expand(scope, lhs.into(), rhs.into(), Arithmetic::Max).into()
255 }
256}
257
/// Declares a cube binary function `$method_name(self, rhs: Self) -> Self` as the trait
/// `$trait_name`, together with a companion `<$trait_name>Expand` trait that carries the
/// expansion-time method.
///
/// - `$operator`: the IR operator constructor handed to `binary_expand`.
/// - `$type`: every primitive that receives an (empty, all-defaults) impl of `$trait_name`.
macro_rules! impl_binary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            // User-facing trait. The host-side method body is `unexpanded!()`; presumably it
            // is only meant to be written inside `#[cube]` code that gets expanded (TODO confirm).
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                fn $method_name(self, _rhs: Self) -> Self {
                    unexpanded!()
                }

                // Static expansion entry point; forwards to the method form on the expand type.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            // Method-call form of the expansion, implemented on expand values.
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            // Opt each listed primitive into the trait; all methods use their defaults.
            $(impl $trait_name for $type {})*
            // The expansion itself: one `binary_expand` call with `$operator`.
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
289
/// Like `impl_binary_func!`, but the function returns `Self::Scalar` instead of `Self`
/// (e.g. a reduction of two vectors to one scalar).
///
/// - `$operator`: the IR operator constructor handed to `binary_expand_fixed_output`.
/// - `$type`: every primitive that receives an (empty, all-defaults) impl of `$trait_name`.
macro_rules! impl_binary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            // The expand type must also expose its scalar counterpart so the static entry
            // point can return `NativeExpand<Self::Scalar>`.
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                // Host-side body is `unexpanded!()`; presumably only valid in expanded
                // `#[cube]` code (TODO confirm).
                fn $method_name(self, _rhs: Self) -> Self::Scalar {
                    unexpanded!()
                }

                // Static expansion entry point; forwards to the method form.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    lhs.[<__expand_ $method_name _method>](scope, rhs)
                }
            }

            // Method-call form of the expansion, returning the scalar expand type.
            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*
            impl<T: CubePrimitive + $trait_name> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Self) -> Self::Scalar {
                    let lhs: ManagedVariable = self.into();
                    // Output type is the lhs type with its vector size overridden to 0;
                    // presumably 0 selects the non-vectorized (scalar) form — confirm in cubecl_ir.
                    let item = lhs.ty.with_vector_size(0);
                    binary_expand_fixed_output(scope, lhs, rhs.into(), item, $operator).into()
                }
            }
        }
    }
}
325
/// Like `impl_binary_func!`, but the right-hand operand has its own type `Rhs`.
///
/// - `$rhs_ty`: the `Rhs` used for the impls generated for each listed `$type`.
/// - `$operator`: the IR operator constructor handed to `binary_expand`.
macro_rules! impl_binary_func_mixed_types {
    ($trait_name:ident, $method_name:ident, $rhs_ty: ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name<Rhs: CubePrimitive + CubeType<ExpandType: Into<ManagedVariable>> + Sized>:
                CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]<Rhs>> + Sized {
                // Host-side body is `unexpanded!()`; presumably only valid in expanded
                // `#[cube]` code (TODO confirm).
                fn $method_name(self, _rhs: Rhs) -> Self {
                    unexpanded!()
                }

                // Unlike the other macros, the static entry point calls `binary_expand`
                // directly rather than forwarding to the `Expand` trait method.
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Rhs>,
                ) -> NativeExpand<Self> {
                    binary_expand(scope, lhs.into(), rhs.into(), $operator).into()
                }
            }

            // Method-call form; `rhs` is the expand type of `Rhs`.
            pub trait [<$trait_name Expand>]<Rhs: CubeType>{
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: Rhs::ExpandType) -> Self;
            }

            $(impl $trait_name<$rhs_ty> for $type {})*
            impl<Rhs: CubePrimitive, T: CubePrimitive + $trait_name<Rhs>> [<$trait_name Expand>]<Rhs> for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope, rhs: NativeExpand<Rhs>) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $operator).into()
                }
            }
        }
    }
}
357
/// Bridges a `core::ops` binary operator trait (`$trait` / `$method`) into cube expansion.
///
/// Generates `Cube$trait` (static `__expand_$method`) and `$trait Expand` (method form), and
/// blanket-implements both for every `CubePrimitive` whose `$trait` impl has `Output = Self`.
/// The expansion emits `$op` through `binary_expand`. `$method` only names the generated
/// functions, so it must match `$trait`'s method or the names collide with other operators.
macro_rules! impl_core_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self;
            }

            // Blanket impls: any primitive implementing the core operator gets the cube version.
            impl<T: $trait<Output = T> + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait<Output = T> + CubePrimitive> [<$trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self {
                    binary_expand(scope, self.into(), rhs.into(), $op).into()
                }
            }
        }
    };
}
384
/// Bridges a `core::ops` compound-assignment trait (`$trait` / `$method`) into cube expansion.
///
/// Generates `Cube$trait` (static `__expand_$method`) and `$trait Expand` (method form), and
/// blanket-implements both for every `CubePrimitive` implementing `$trait`. The expansion
/// returns nothing and hands `$op` to `assign_op_expand` (defined elsewhere; presumably it
/// writes the result back into `lhs` — TODO confirm).
macro_rules! impl_core_assign_binop {
    ($trait: ident, $method: ident, $op: expr) => {
        paste::paste! {
            pub trait [<Cube $trait>]: $trait + CubePrimitive + CubeType<ExpandType: [<$trait Expand>]> + Sized {
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    lhs: NativeExpand<Self>,
                    rhs: NativeExpand<Self>,
                ) {
                    lhs.[<__expand_ $method _method>](scope, rhs)
                }
            }

            pub trait [<$trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self);
            }

            // Blanket impls: any primitive implementing the core assign operator gets the cube version.
            impl<T: $trait + CubePrimitive> [<Cube $trait>] for T {}
            impl<T: $trait + CubePrimitive> [<$trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) {
                    assign_op_expand(scope, self.into(), rhs.into(), $op);
                }
            }
        }
    };
}
411
412impl_core_binop!(Add, add, Arithmetic::Add);
413impl_core_binop!(Sub, sub, Arithmetic::Sub);
414impl_core_binop!(Mul, mul, Arithmetic::Mul);
415impl_core_binop!(Div, mul, Arithmetic::Div);
416impl_core_binop!(Rem, rem, Arithmetic::Modulo);
417
418impl_core_assign_binop!(AddAssign, add_assign, Arithmetic::Add);
419impl_core_assign_binop!(SubAssign, sub_assign, Arithmetic::Sub);
420impl_core_assign_binop!(MulAssign, mul_assign, Arithmetic::Mul);
421impl_core_assign_binop!(DivAssign, div_assign, Arithmetic::Div);
422impl_core_assign_binop!(RemAssign, rem_assign, Arithmetic::Modulo);
423
// Expansion-side description of `Ordering`: `Less` / `Equal` / `Greater` with
// discriminants -1 / 0 / 1. The generated `OrderingExpand` stores the variant as an
// `i32` discriminant (see `ordering_disc`) with a `()` payload.
// NOTE(review): `Ordering` is also imported from `core::cmp` at the top of the file, and
// `CubeOrd: Ord` maps `cmp` to `OrderingExpand`, so `derive_expand` presumably generates the
// expand type and impls for that existing type rather than declaring a new enum — TODO
// confirm in `cubecl_macros`.
#[derive_expand(CubeType, CubeTypeMut, IntoRuntime)]
#[cube(runtime_variants, no_constructors)]
pub enum Ordering {
    Less = -1,
    Equal = 0,
    Greater = 1,
}
431
432fn ordering_disc(name: &'static str) -> NativeExpand<i32> {
433 OrderingExpand::discriminant_of(name).into()
434}
435
/// Variant constructors for `Ordering` that can be called from cube code.
///
/// The host-side functions return the plain variants. The `__expand_*` functions build an
/// `OrderingExpand` whose `discriminant` comes from `ordering_disc` and whose `value` is `()`.
// Function names deliberately mirror the variant names (`Less`, `Equal`, `Greater`).
#[allow(non_snake_case)]
pub trait CubeOrdering {
    fn Less() -> Ordering {
        Ordering::Less
    }
    fn Equal() -> Ordering {
        Ordering::Equal
    }
    fn Greater() -> Ordering {
        Ordering::Greater
    }
    fn __expand_Less(_scope: &mut Scope) -> OrderingExpand {
        OrderingExpand {
            discriminant: ordering_disc("Less"),
            value: (),
        }
    }
    fn __expand_Equal(_scope: &mut Scope) -> OrderingExpand {
        OrderingExpand {
            discriminant: ordering_disc("Equal"),
            value: (),
        }
    }
    fn __expand_Greater(_scope: &mut Scope) -> OrderingExpand {
        OrderingExpand {
            discriminant: ordering_disc("Greater"),
            value: (),
        }
    }
}

// All constructors use the trait's default bodies.
impl CubeOrdering for Ordering {}
468
469pub trait CubeOrd: Ord + CubeType<ExpandType: OrdExpand> + Sized {
470 fn __expand_cmp(
471 scope: &mut Scope,
472 lhs: Self::ExpandType,
473 rhs: Self::ExpandType,
474 ) -> OrderingExpand {
475 lhs.__expand_cmp_method(scope, rhs)
476 }
477
478 fn __expand_min(
479 scope: &mut Scope,
480 lhs: Self::ExpandType,
481 rhs: Self::ExpandType,
482 ) -> Self::ExpandType {
483 lhs.__expand_min_method(scope, rhs)
484 }
485
486 fn __expand_max(
487 scope: &mut Scope,
488 lhs: Self::ExpandType,
489 rhs: Self::ExpandType,
490 ) -> Self::ExpandType {
491 lhs.__expand_max_method(scope, rhs)
492 }
493
494 fn __expand_clamp(
495 scope: &mut Scope,
496 lhs: Self::ExpandType,
497 min: Self::ExpandType,
498 max: Self::ExpandType,
499 ) -> Self::ExpandType {
500 lhs.__expand_clamp_method(scope, min, max)
501 }
502}
503pub trait OrdExpand {
504 fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand;
505 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self;
506 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self;
507 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self;
508}
509
510impl<T: Ord + CubePrimitive> CubeOrd for T {}
511impl<T: Ord + CubePrimitive> OrdExpand for NativeExpand<T> {
512 fn __expand_cmp_method(self, scope: &mut Scope, rhs: Self) -> OrderingExpand {
513 let lhs_lt_rhs = lt::expand(scope, self.clone(), rhs.clone());
514 let lhs_gt_rhs = gt::expand(scope, self, rhs);
515 let less = ordering_disc("Less");
516 let equal = ordering_disc("Equal");
517 let greater = ordering_disc("Greater");
518 let eq_or_gt = select::expand(scope, lhs_gt_rhs, greater, equal);
519 let discriminant = select::expand(scope, lhs_lt_rhs, less, eq_or_gt);
520 OrderingExpand {
521 discriminant,
522 value: (),
523 }
524 }
525 fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
526 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Min).into()
527 }
528 fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
529 binary_expand(scope, self.into(), rhs.into(), Arithmetic::Max).into()
530 }
531 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
532 unary_expand(scope, self.into(), |op| {
533 Arithmetic::Clamp(ClampOperator {
534 input: op.input,
535 min_value: *min.expand,
536 max_value: *max.expand,
537 })
538 })
539 .into()
540 }
541}
542
// `powf`: floating-point power, for the float types plus the two complex types.
impl_binary_func!(
    Powf,
    powf,
    Arithmetic::Powf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);

// `hypot` over the float types (presumably sqrt(a^2 + b^2) — confirm in the IR docs).
impl_binary_func!(
    Hypot,
    hypot,
    Arithmetic::Hypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);

// `rhypot` over the float types (presumably the reciprocal of `hypot` — confirm).
impl_binary_func!(
    Rhypot,
    rhypot,
    Arithmetic::Rhypot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);

// Two-argument arctangent `atan2` over the float types.
impl_binary_func!(
    ArcTan2,
    atan2,
    Arithmetic::ArcTan2,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Remainder::rem` lowers to `Arithmetic::Remainder`, which is a different IR op from the
// `%` operator (that one maps to `Arithmetic::Modulo`). Presumably they differ in sign
// handling — confirm in the IR docs.
impl_binary_func!(
    Remainder,
    rem,
    Arithmetic::Remainder,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// `mul_hi` (presumably the high half of the widened product — confirm), limited to the
// listed integer types.
impl_binary_func!(MulHi, mul_hi, Arithmetic::MulHi, i32, u32, usize, isize);
// Saturating integer addition.
impl_binary_func!(
    SaturatingAdd,
    saturating_add,
    Arithmetic::SaturatingAdd,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// Saturating integer subtraction.
impl_binary_func!(
    SaturatingSub,
    saturating_sub,
    Arithmetic::SaturatingSub,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// `dot` returns `Self::Scalar` (scalar-output macro) rather than `Self`.
impl_binary_func_scalar_out!(
    Dot,
    dot,
    Arithmetic::Dot,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);

// `powi`: power with an `i32` exponent for every listed base type.
impl_binary_func_mixed_types!(
    Powi,
    powi,
    i32,
    Arithmetic::Powi,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);