1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator};
4use half::{bf16, f16};
5
6use crate::{
7 flex32,
8 ir::{Arithmetic, ManagedVariable, Scope},
9 prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret, Scalar},
10 tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16 use super::*;
17
18 pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19 if x.expand.ty.is_bool() {
20 unary_expand(scope, x.into(), Operator::Not).into()
21 } else {
22 unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23 }
24 }
25}
26
27pub mod neg {
28 use super::*;
29
30 pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31 unary_expand(scope, x.into(), Arithmetic::Neg).into()
32 }
33}
34
// Generates the trait pair behind a unary operation whose output has the same
// cube type as its input.
//
// For `impl_unary_func!(Foo, foo, OP, A, B)` the expansion contains:
// - `trait Foo`, with a `foo(self) -> Self` method whose default body is
//   `unexpanded!()`, and an `__expand_foo` entry point that forwards to the
//   expand-side method;
// - `trait FooExpand`, declaring `__expand_foo_method`;
// - empty `impl Foo for A {}` / `impl Foo for B {}` items;
// - a blanket `FooExpand` impl for `NativeExpand<T>` that hands `OP` to
//   `unary_expand`.
macro_rules! impl_unary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name:
                CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized
            {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl $trait_name for $type {})*

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    let output = unary_expand(scope, self.into(), $operator);
                    output.into()
                }
            }
        }
    }
}
62
63impl Exp for f32 {
64 fn exp(self) -> Self {
65 self.exp()
66 }
67}
68
69pub trait Abs:
70 CubePrimitive
71 + CubeType<
72 ExpandType: AbsExpand<
73 AbsElem = Self::AbsElem,
74 AbsOut = NativeExpand<Self::WithScalar<Self::AbsElem>>,
75 >,
76 > + Sized
77{
78 type AbsElem: Scalar;
79
80 #[allow(unused_variables)]
81 fn abs(self) -> Self::WithScalar<Self::AbsElem> {
82 unexpanded!()
83 }
84
85 fn __expand_abs(
86 scope: &mut Scope,
87 x: NativeExpand<Self>,
88 ) -> NativeExpand<Self::WithScalar<Self::AbsElem>> {
89 x.__expand_abs_method(scope)
90 }
91}
92
93pub trait AbsExpand: CubePrimitiveExpand {
94 type AbsElem: Scalar;
95 type AbsOut;
96
97 fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut;
98}
99
100impl<T: Abs> AbsExpand for NativeExpand<T> {
101 type AbsElem = T::AbsElem;
102 type AbsOut = NativeExpand<T::WithScalar<T::AbsElem>>;
103
104 fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut {
105 let expand_element: ManagedVariable = self.into();
106 let item = <T::AbsElem as CubePrimitive>::as_type(scope)
107 .with_vector_size(expand_element.ty.vector_size());
108 unary_expand_fixed_output(scope, expand_element, item, Arithmetic::Abs).into()
109 }
110}
111
// Implements `Abs` for each listed type with `AbsElem` set to the type
// itself.
macro_rules! impl_abs_same_type {
    ($($prim:ty),*) => {
        $(
            impl Abs for $prim {
                type AbsElem = $prim;
            }
        )*
    };
}
119
// Generates the trait pair behind a unary operation whose expand-side result
// is the `Scalar` of the input type.
//
// Same layout as `impl_unary_func!` (user-facing trait, `...Expand` trait,
// empty impls for each `$type`, blanket impl on `NativeExpand<T>`), except
// that the blanket impl calls `unary_expand_fixed_output` with the input's
// type rebuilt via `with_vector_size(0)`.
//
// NOTE(review): the `$method_name` signature returns `Self` while the expand
// path returns `NativeExpand<Self::Scalar>` — confirm this asymmetry is
// intended.
macro_rules! impl_unary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized
            {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let input: ManagedVariable = self.into();
                    // Same type as the input, but with vector size 0.
                    let scalar_ty = input.ty.with_vector_size(0);
                    unary_expand_fixed_output(scope, input, scalar_ty, $operator).into()
                }
            }
        }
    }
}
152
// Generates the trait pair behind a unary operation whose result always has
// element type `$out_ty`, i.e. `Self::WithScalar<$out_ty>`.
//
// Same layout as `impl_unary_func!` (user-facing trait, `...Expand` trait,
// empty impls for each `$type`, blanket impl on `NativeExpand<T>`). The
// blanket impl builds the output type from `$out_ty` and the input's vector
// size, then calls `unary_expand_fixed_output` with `$operator`.
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_name:ident, $method_name:ident, $out_ty:ty, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<WithScalar<$out_ty> = NativeExpand<Self::WithScalar<$out_ty>>>> + Sized
            {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method_name(self) -> Self::WithScalar<$out_ty> {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self::WithScalar<$out_ty>> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                ) -> Self::WithScalar<$out_ty>;
            }

            $(impl $trait_name for $type {})*

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](
                    self,
                    scope: &mut Scope,
                ) -> Self::WithScalar<$out_ty> {
                    let input: ManagedVariable = self.into();
                    // Element type comes from `$out_ty`; the vector size is
                    // copied from the input.
                    let elem_ty = <$out_ty as CubePrimitive>::as_type(scope);
                    let result_ty = elem_ty.with_vector_size(input.ty.vector_size());
                    unary_expand_fixed_output(scope, input, result_ty, $operator).into()
                }
            }
        }
    }
}
183
// Generates `Cube<Trait>` / `<Trait>Expand` for an operator trait that has an
// `Output` associated type (invoked in this file with `Not`).
//
// `Cube<Trait>` requires `<Trait><Output = Self>` and only adds the
// `__expand_<method>` entry point. The blanket `<Trait>Expand` impl on
// `NativeExpand<T>` delegates to `not::expand`.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($type:ty),*) => {
        paste::paste! {
            pub trait [<Cube $trait_name>]:
                $trait_name<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]>
            {
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            $(impl [<Cube $trait_name>] for $type {})*

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: [<Cube $trait_name>] + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    // `self` already has the type `not::expand` expects.
                    not::expand(scope, self)
                }
            }
        }
    }
}
207
// `CubeNot` / `NotExpand` for `bool` and the integer primitives.
impl_not!(
    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
);

// `Abs` impls whose `AbsElem` is the implementing type itself.
impl_abs_same_type!(
    e2m1, e4m3, e5m2, ue8m0, f16, bf16, flex32, tf32, f32, f64, i8, i16, i32, i64, u8, u16, u32,
    u64, usize, isize
);
// `Exp::exp` -> `Arithmetic::Exp`.
// `f32` is not listed here because it has a hand-written `impl Exp for f32`
// elsewhere in this file.
impl_unary_func!(
    Exp,
    exp,
    Arithmetic::Exp,
    f16,
    bf16,
    flex32,
    tf32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `Log::ln` -> `Arithmetic::Log` (real and complex types).
impl_unary_func!(
    Log,
    ln,
    Arithmetic::Log,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `Log1p::log1p` -> `Arithmetic::Log1p`.
impl_unary_func!(
    Log1p,
    log1p,
    Arithmetic::Log1p,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Expm1::exp_m1` -> `Arithmetic::Expm1`.
impl_unary_func!(
    Expm1,
    exp_m1,
    Arithmetic::Expm1,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Cos::cos` -> `Arithmetic::Cos` (real and complex types).
impl_unary_func!(
    Cos,
    cos,
    Arithmetic::Cos,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `Sin::sin` -> `Arithmetic::Sin` (real and complex types).
impl_unary_func!(
    Sin,
    sin,
    Arithmetic::Sin,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `Tan::tan` -> `Arithmetic::Tan`.
impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
// `Tanh::tanh` -> `Arithmetic::Tanh` (real and complex types).
impl_unary_func!(
    Tanh,
    tanh,
    Arithmetic::Tanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `Sinh::sinh` -> `Arithmetic::Sinh`.
impl_unary_func!(
    Sinh,
    sinh,
    Arithmetic::Sinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Cosh::cosh` -> `Arithmetic::Cosh`.
impl_unary_func!(
    Cosh,
    cosh,
    Arithmetic::Cosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcCos::acos` -> `Arithmetic::ArcCos`.
impl_unary_func!(
    ArcCos,
    acos,
    Arithmetic::ArcCos,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcSin::asin` -> `Arithmetic::ArcSin`.
impl_unary_func!(
    ArcSin,
    asin,
    Arithmetic::ArcSin,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcTan::atan` -> `Arithmetic::ArcTan`.
impl_unary_func!(
    ArcTan,
    atan,
    Arithmetic::ArcTan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcSinh::asinh` -> `Arithmetic::ArcSinh`.
impl_unary_func!(
    ArcSinh,
    asinh,
    Arithmetic::ArcSinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcCosh::acosh` -> `Arithmetic::ArcCosh`.
impl_unary_func!(
    ArcCosh,
    acosh,
    Arithmetic::ArcCosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `ArcTanh::atanh` -> `Arithmetic::ArcTanh`.
impl_unary_func!(
    ArcTanh,
    atanh,
    Arithmetic::ArcTanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Degrees::to_degrees` -> `Arithmetic::Degrees`.
impl_unary_func!(
    Degrees,
    to_degrees,
    Arithmetic::Degrees,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Radians::to_radians` -> `Arithmetic::Radians`.
impl_unary_func!(
    Radians,
    to_radians,
    Arithmetic::Radians,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Sqrt::sqrt` -> `Arithmetic::Sqrt` (real and complex types).
impl_unary_func!(
    Sqrt,
    sqrt,
    Arithmetic::Sqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    num_complex::Complex<f32>,
    num_complex::Complex<f64>
);
// `InverseSqrt::inverse_sqrt` -> `Arithmetic::InverseSqrt`.
impl_unary_func!(
    InverseSqrt,
    inverse_sqrt,
    Arithmetic::InverseSqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Round::round` -> `Arithmetic::Round`.
impl_unary_func!(
    Round,
    round,
    Arithmetic::Round,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Floor::floor` -> `Arithmetic::Floor`.
impl_unary_func!(
    Floor,
    floor,
    Arithmetic::Floor,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Ceil::ceil` -> `Arithmetic::Ceil`.
impl_unary_func!(
    Ceil,
    ceil,
    Arithmetic::Ceil,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Trunc::trunc` -> `Arithmetic::Trunc`.
impl_unary_func!(
    Trunc,
    trunc,
    Arithmetic::Trunc,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Erf::erf` -> `Arithmetic::Erf`.
impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
// `Recip::recip` -> `Arithmetic::Recip`.
impl_unary_func!(
    Recip,
    recip,
    Arithmetic::Recip,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Magnitude::magnitude` -> `Arithmetic::Magnitude`; generated with the
// scalar-output macro, so the expand-side result is the input's `Scalar`.
impl_unary_func_scalar_out!(
    Magnitude,
    magnitude,
    Arithmetic::Magnitude,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `VectorSum::vector_sum` -> `Arithmetic::VectorSum`; also scalar-output,
// and implemented for the float and integer primitives listed below.
impl_unary_func_scalar_out!(
    VectorSum,
    vector_sum,
    Arithmetic::VectorSum,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// `Normalize::normalize` -> `Arithmetic::Normalize`; output type equals the
// input type.
impl_unary_func!(
    Normalize,
    normalize,
    Arithmetic::Normalize,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `CountOnes::count_ones` -> `Bitwise::CountOnes`; result element type `u32`.
impl_unary_func_fixed_out_ty!(
    CountOnes,
    count_ones,
    u32,
    Bitwise::CountOnes,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// `ReverseBits::reverse_bits` -> `Bitwise::ReverseBits`; output type equals
// the input type.
impl_unary_func!(
    ReverseBits,
    reverse_bits,
    Bitwise::ReverseBits,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);

// `LeadingZeros::leading_zeros` -> `Bitwise::LeadingZeros`; result element
// type `u32`.
impl_unary_func_fixed_out_ty!(
    LeadingZeros,
    leading_zeros,
    u32,
    Bitwise::LeadingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// `TrailingZeros::trailing_zeros` -> `Bitwise::TrailingZeros`; result element
// type `u32`.
impl_unary_func_fixed_out_ty!(
    TrailingZeros,
    trailing_zeros,
    u32,
    Bitwise::TrailingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// `FindFirstSet::find_first_set` -> `Bitwise::FindFirstSet`; result element
// type `u32`.
impl_unary_func_fixed_out_ty!(
    FindFirstSet,
    find_first_set,
    u32,
    Bitwise::FindFirstSet,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// `IsNan::is_nan` -> `Comparison::IsNan`; result element type `bool`.
impl_unary_func_fixed_out_ty!(
    IsNan,
    is_nan,
    bool,
    Comparison::IsNan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `IsInf::is_inf` -> `Comparison::IsInf`; result element type `bool`.
impl_unary_func_fixed_out_ty!(
    IsInf,
    is_inf,
    bool,
    Comparison::IsInf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
645
646pub trait FloatBits:
647 CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
648{
649 type Bits: CubePrimitive;
650
651 fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
652 Self::__expand_reinterpret(scope, bits)
653 }
654
655 fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
656 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
657 }
658}
659
660pub trait FloatBitsExpand: Sized {
661 type Bits: CubePrimitive;
662
663 fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
664}
665
666impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
667 type Bits = F::Bits;
668
669 fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
670 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
671 }
672}
673
// `FloatBits` impls: each one only selects the primitive that carries the
// bits of the float type.

// Bits carried in a `u8`.
impl FloatBits for e2m1x2 {
    type Bits = u8;
}

// Bits carried in a `u8`.
impl FloatBits for e5m2 {
    type Bits = u8;
}

// Bits carried in a `u8`.
impl FloatBits for e4m3 {
    type Bits = u8;
}

// Bits carried in a `u16`.
impl FloatBits for f16 {
    type Bits = u16;
}

// Bits carried in a `u16`.
impl FloatBits for bf16 {
    type Bits = u16;
}

// Bits carried in a `u32`.
impl FloatBits for f32 {
    type Bits = u32;
}

// Bits carried in a `u64`.
impl FloatBits for f64 {
    type Bits = u64;
}