cubecl_core/frontend/element/
complex.rs1use core::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::{
4 Runtime,
5 ir::{ComplexKind, ElemType, ManagedVariable, Scope, StorageType, Type},
6 prelude::{CubePrimitive, CubeType, IntoRuntime, NativeAssign, NativeExpand, Scalar},
7 unexpanded,
8};
9use cubecl_ir::{Arithmetic, ConstantValue, Operator, features::ComplexUsage};
10use cubecl_runtime::client::ComputeClient;
11
12use crate::frontend::{
13 Abs,
14 operation::{unary_expand, unary_expand_fixed_output},
15};
16
17pub trait ComplexCore:
18 CubePrimitive
19 + Add<Output = Self>
20 + Sub<Output = Self>
21 + Mul<Output = Self>
22 + Div<Output = Self>
23 + Neg<Output = Self>
24 + Copy
25 + Clone
26 + PartialEq
27 + core::fmt::Debug
28 + Send
29 + Sync
30 + 'static
31{
32 type FloatElem: Scalar;
33
34 fn conj(self) -> Self {
35 unexpanded!()
36 }
37
38 fn real_val(self) -> Self::FloatElem {
39 unexpanded!()
40 }
41
42 fn imag_val(self) -> Self::FloatElem {
43 unexpanded!()
44 }
45
46 fn supported_complex_uses<R: Runtime>(
47 client: &ComputeClient<R>,
48 ) -> enumset::EnumSet<ComplexUsage> {
49 client
50 .properties()
51 .complex_usage(Self::as_type_native_unchecked().storage_type())
52 }
53}
54
55pub trait ComplexCompare: ComplexCore {}
56
57pub trait ComplexMath:
58 ComplexCore
59 + Abs<AbsElem = Self::FloatElem>
60 + crate::frontend::Exp
61 + crate::frontend::Log
62 + crate::frontend::Sin
63 + crate::frontend::Cos
64 + crate::frontend::Sqrt
65 + crate::frontend::Tanh
66 + crate::frontend::Powf
67{
68}
69
70pub trait ComplexCoreExpand {
71 fn __expand_conj_method(self, scope: &mut Scope) -> Self;
72 fn __expand_real_val_method(
73 self,
74 scope: &mut Scope,
75 ) -> NativeExpand<<Self as ComplexCoreExpand>::FloatElem>;
76 fn __expand_imag_val_method(
77 self,
78 scope: &mut Scope,
79 ) -> NativeExpand<<Self as ComplexCoreExpand>::FloatElem>;
80
81 type FloatElem: Scalar;
82}
83
84impl<T: ComplexCore> ComplexCoreExpand for NativeExpand<T> {
85 type FloatElem = T::FloatElem;
86
87 fn __expand_conj_method(self, scope: &mut Scope) -> Self {
88 unary_expand(scope, self.into(), Arithmetic::Conj).into()
89 }
90
91 fn __expand_real_val_method(self, scope: &mut Scope) -> NativeExpand<T::FloatElem> {
92 let expand_element: ManagedVariable = self.into();
93 let item = <T::FloatElem as CubePrimitive>::as_type(scope);
94 unary_expand_fixed_output(scope, expand_element, item, Operator::Real).into()
95 }
96
97 fn __expand_imag_val_method(self, scope: &mut Scope) -> NativeExpand<T::FloatElem> {
98 let expand_element: ManagedVariable = self.into();
99 let item = <T::FloatElem as CubePrimitive>::as_type(scope);
100 unary_expand_fixed_output(scope, expand_element, item, Operator::Imag).into()
101 }
102}
103
// Implements the CubeCL primitive/scalar trait family for a concrete complex type.
//
// Arguments (positional):
// - `$complex`: the complex type, e.g. `num_complex::Complex<f32>`;
// - `$variant`: the matching `ComplexKind` variant name, e.g. `C32`;
// - `$real`:    the component type, used for `FloatElem`, `AbsElem` and for
//               casting constant components in `from_const_value`.
macro_rules! impl_complex {
    ($complex:ty, $variant:ident, $real:ty) => {
        impl CubeType for $complex {
            type ExpandType = NativeExpand<$complex>;
        }

        impl CubePrimitive for $complex {
            type Scalar = Self;
            type Size = crate::prelude::Const<1>;
            type WithScalar<S: Scalar> = S;

            fn as_type_native() -> Option<Type> {
                let elem = ElemType::Complex(ComplexKind::$variant);
                Some(StorageType::Scalar(elem).into())
            }

            fn from_const_value(value: ConstantValue) -> Self {
                match value {
                    // Components are cast to this complex kind's real element type.
                    ConstantValue::Complex(re, im) => <$complex>::new(re as $real, im as $real),
                    _ => unreachable!("expected Complex constant"),
                }
            }
        }

        impl Scalar for $complex {}

        impl IntoRuntime for $complex {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl crate::prelude::IntoMut for $complex {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl NativeAssign for $complex {}

        // `abs` of a complex value yields its real component type.
        impl Abs for $complex {
            type AbsElem = $real;
        }

        impl ComplexCore for $complex {
            type FloatElem = $real;
        }

        impl ComplexCompare for $complex {}

        impl ComplexMath for $complex {}
    };
}
156
157impl_complex!(num_complex::Complex<f32>, C32, f32);
158impl_complex!(num_complex::Complex<f64>, C64, f64);