Skip to main content

cubecl_core/frontend/element/
complex.rs

1use core::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::{
4    Runtime,
5    ir::{ComplexKind, ElemType, ManagedVariable, Scope, StorageType, Type},
6    prelude::{CubePrimitive, CubeType, IntoRuntime, NativeAssign, NativeExpand, Scalar},
7    unexpanded,
8};
9use cubecl_ir::{Arithmetic, ConstantValue, Operator, features::ComplexUsage};
10use cubecl_runtime::client::ComputeClient;
11
12use crate::frontend::{
13    Abs,
14    operation::{unary_expand, unary_expand_fixed_output},
15};
16
17pub trait ComplexCore:
18    CubePrimitive
19    + Add<Output = Self>
20    + Sub<Output = Self>
21    + Mul<Output = Self>
22    + Div<Output = Self>
23    + Neg<Output = Self>
24    + Copy
25    + Clone
26    + PartialEq
27    + core::fmt::Debug
28    + Send
29    + Sync
30    + 'static
31{
32    type FloatElem: Scalar;
33
34    fn conj(self) -> Self {
35        unexpanded!()
36    }
37
38    fn real_val(self) -> Self::FloatElem {
39        unexpanded!()
40    }
41
42    fn imag_val(self) -> Self::FloatElem {
43        unexpanded!()
44    }
45
46    fn supported_complex_uses<R: Runtime>(
47        client: &ComputeClient<R>,
48    ) -> enumset::EnumSet<ComplexUsage> {
49        client
50            .properties()
51            .complex_usage(Self::as_type_native_unchecked().storage_type())
52    }
53}
54
55pub trait ComplexCompare: ComplexCore {}
56
57pub trait ComplexMath:
58    ComplexCore
59    + Abs<AbsElem = Self::FloatElem>
60    + crate::frontend::Exp
61    + crate::frontend::Log
62    + crate::frontend::Sin
63    + crate::frontend::Cos
64    + crate::frontend::Sqrt
65    + crate::frontend::Tanh
66    + crate::frontend::Powf
67{
68}
69
70pub trait ComplexCoreExpand {
71    fn __expand_conj_method(self, scope: &mut Scope) -> Self;
72    fn __expand_real_val_method(
73        self,
74        scope: &mut Scope,
75    ) -> NativeExpand<<Self as ComplexCoreExpand>::FloatElem>;
76    fn __expand_imag_val_method(
77        self,
78        scope: &mut Scope,
79    ) -> NativeExpand<<Self as ComplexCoreExpand>::FloatElem>;
80
81    type FloatElem: Scalar;
82}
83
84impl<T: ComplexCore> ComplexCoreExpand for NativeExpand<T> {
85    type FloatElem = T::FloatElem;
86
87    fn __expand_conj_method(self, scope: &mut Scope) -> Self {
88        unary_expand(scope, self.into(), Arithmetic::Conj).into()
89    }
90
91    fn __expand_real_val_method(self, scope: &mut Scope) -> NativeExpand<T::FloatElem> {
92        let expand_element: ManagedVariable = self.into();
93        let item = <T::FloatElem as CubePrimitive>::as_type(scope);
94        unary_expand_fixed_output(scope, expand_element, item, Operator::Real).into()
95    }
96
97    fn __expand_imag_val_method(self, scope: &mut Scope) -> NativeExpand<T::FloatElem> {
98        let expand_element: ManagedVariable = self.into();
99        let item = <T::FloatElem as CubePrimitive>::as_type(scope);
100        unary_expand_fixed_output(scope, expand_element, item, Operator::Imag).into()
101    }
102}
103
/// Implements the frontend traits for one complex primitive.
///
/// * `$ty` is the concrete complex type being registered.
/// * `$variant` is the `ComplexKind` variant used for its IR storage type.
/// * `$component` is the float type of its real and imaginary parts. It becomes
///   `FloatElem` / `AbsElem`, and constant components are cast to it with `as`.
macro_rules! impl_complex {
    ($ty:ty, $variant:ident, $component:ty) => {
        // --- type registration -------------------------------------------------------
        impl CubeType for $ty {
            type ExpandType = NativeExpand<$ty>;
        }

        impl CubePrimitive for $ty {
            type Scalar = Self;
            type Size = crate::prelude::Const<1>;
            type WithScalar<S: Scalar> = S;

            fn as_type_native() -> Option<Type> {
                let storage = StorageType::Scalar(ElemType::Complex(ComplexKind::$variant));
                Some(storage.into())
            }

            fn from_const_value(value: ConstantValue) -> Self {
                // Only complex constants are accepted; anything else is treated as a bug.
                match value {
                    ConstantValue::Complex(re, im) => {
                        <$ty>::new(re as $component, im as $component)
                    }
                    _ => unreachable!("expected Complex constant"),
                }
            }
        }

        impl Scalar for $ty {}

        // --- runtime / mutability plumbing (identity conversions) ---------------------
        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl crate::prelude::IntoMut for $ty {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl NativeAssign for $ty {}

        // --- complex capability traits ------------------------------------------------
        impl Abs for $ty {
            type AbsElem = $component;
        }

        impl ComplexCore for $ty {
            type FloatElem = $component;
        }

        impl ComplexCompare for $ty {}

        impl ComplexMath for $ty {}
    };
}
156
157impl_complex!(num_complex::Complex<f32>, C32, f32);
158impl_complex!(num_complex::Complex<f64>, C64, f64);