1use crate as cubecl;
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use cubecl::prelude::*;
8use cubecl_ir::{
9 Arithmetic, Bitwise, Comparison, Operation, Operator, Scope, StorageType, Type,
10 features::ComplexUsage,
11};
12use cubecl_macros::intrinsic;
13
14#[cube]
15#[allow(unused_variables)]
16pub fn push_validation_error(#[comptime] msg: String) {
22 intrinsic! {|scope| scope.push_error(msg)}
23}
24
25fn collect_complex_storage_types(types: impl IntoIterator<Item = Type>) -> Vec<StorageType> {
26 let mut storages = Vec::new();
27
28 for ty in types {
29 if ty.is_semantic() {
30 continue;
31 }
32
33 let storage = ty.storage_type();
34 if storage.elem_type().is_complex() && !storages.contains(&storage) {
35 storages.push(storage);
36 }
37 }
38
39 storages
40}
41
42fn require_complex_usage(
43 scope: &mut Scope,
44 types: impl IntoIterator<Item = Type>,
45 usage: ComplexUsage,
46 op_name: &'static str,
47) {
48 let storages = collect_complex_storage_types(types);
49 if storages.is_empty() {
50 return;
51 }
52
53 let Some(properties) = scope.properties.clone() else {
59 return;
60 };
61
62 for storage in storages {
63 if !properties.supports_complex_usage(storage, usage) {
64 scope.push_error(format!(
65 "Complex operation `{op_name}` requires {usage:?} support for `{storage}`, but the active runtime does not advertise it."
66 ));
67 }
68 }
69}
70
71fn reject_complex_operation(
72 scope: &mut Scope,
73 types: impl IntoIterator<Item = Type>,
74 op_name: &'static str,
75) {
76 let storages = collect_complex_storage_types(types);
77 if storages.is_empty() {
78 return;
79 }
80
81 let supported = storages
82 .into_iter()
83 .map(|storage| storage.to_string())
84 .collect::<Vec<_>>()
85 .join(", ");
86 scope.push_error(format!(
87 "Complex operation `{op_name}` is not part of the v1 complex contract for `{supported}`."
88 ));
89}
90
91pub(crate) fn validate_complex_operation(scope: &mut Scope, operation: &Operation) {
92 match operation {
93 Operation::Arithmetic(arithmetic) => match arithmetic {
94 Arithmetic::Add(op) => {
95 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "+")
96 }
97 Arithmetic::Sub(op) => {
98 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "-")
99 }
100 Arithmetic::Mul(op) => {
101 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "*")
102 }
103 Arithmetic::Div(op) => {
104 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "/")
105 }
106 Arithmetic::Neg(op) => {
107 require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "neg")
108 }
109 Arithmetic::Conj(op) => {
110 require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "conj")
111 }
112 Arithmetic::Abs(op) => {
113 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "abs")
114 }
115 Arithmetic::Exp(op) => {
116 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "exp")
117 }
118 Arithmetic::Log(op) => {
119 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "log")
120 }
121 Arithmetic::Sin(op) => {
122 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "sin")
123 }
124 Arithmetic::Cos(op) => {
125 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "cos")
126 }
127 Arithmetic::Sqrt(op) => {
128 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "sqrt")
129 }
130 Arithmetic::Tanh(op) => {
131 require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "tanh")
132 }
133 Arithmetic::Powf(op) => {
134 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Math, "powf")
135 }
136 Arithmetic::Fma(op) => {
137 reject_complex_operation(scope, [op.a.ty, op.b.ty, op.c.ty], "fma")
138 }
139 Arithmetic::SaturatingAdd(op) => {
140 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "saturating_add")
141 }
142 Arithmetic::SaturatingSub(op) => {
143 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "saturating_sub")
144 }
145 Arithmetic::Log1p(op) => reject_complex_operation(scope, [op.input.ty], "log1p"),
146 Arithmetic::Expm1(op) => reject_complex_operation(scope, [op.input.ty], "exp_m1"),
147 Arithmetic::Tan(op) => reject_complex_operation(scope, [op.input.ty], "tan"),
148 Arithmetic::Sinh(op) => reject_complex_operation(scope, [op.input.ty], "sinh"),
149 Arithmetic::Cosh(op) => reject_complex_operation(scope, [op.input.ty], "cosh"),
150 Arithmetic::ArcCos(op) => reject_complex_operation(scope, [op.input.ty], "acos"),
151 Arithmetic::ArcSin(op) => reject_complex_operation(scope, [op.input.ty], "asin"),
152 Arithmetic::ArcTan(op) => reject_complex_operation(scope, [op.input.ty], "atan"),
153 Arithmetic::ArcSinh(op) => reject_complex_operation(scope, [op.input.ty], "asinh"),
154 Arithmetic::ArcCosh(op) => reject_complex_operation(scope, [op.input.ty], "acosh"),
155 Arithmetic::ArcTanh(op) => reject_complex_operation(scope, [op.input.ty], "atanh"),
156 Arithmetic::Degrees(op) => reject_complex_operation(scope, [op.input.ty], "to_degrees"),
157 Arithmetic::Radians(op) => reject_complex_operation(scope, [op.input.ty], "to_radians"),
158 Arithmetic::ArcTan2(op) => {
159 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "atan2")
160 }
161 Arithmetic::Powi(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "powi"),
162 Arithmetic::Hypot(op) => {
163 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "hypot")
164 }
165 Arithmetic::Rhypot(op) => {
166 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "rhypot")
167 }
168 Arithmetic::InverseSqrt(op) => {
169 reject_complex_operation(scope, [op.input.ty], "inverse_sqrt")
170 }
171 Arithmetic::Round(op) => reject_complex_operation(scope, [op.input.ty], "round"),
172 Arithmetic::Floor(op) => reject_complex_operation(scope, [op.input.ty], "floor"),
173 Arithmetic::Ceil(op) => reject_complex_operation(scope, [op.input.ty], "ceil"),
174 Arithmetic::Trunc(op) => reject_complex_operation(scope, [op.input.ty], "trunc"),
175 Arithmetic::Erf(op) => reject_complex_operation(scope, [op.input.ty], "erf"),
176 Arithmetic::Recip(op) => reject_complex_operation(scope, [op.input.ty], "recip"),
177 Arithmetic::Clamp(op) => reject_complex_operation(
178 scope,
179 [op.input.ty, op.min_value.ty, op.max_value.ty],
180 "clamp",
181 ),
182 Arithmetic::Modulo(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "%"),
183 Arithmetic::Max(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "max"),
184 Arithmetic::Min(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "min"),
185 Arithmetic::Remainder(op) => {
186 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "remainder")
187 }
188 Arithmetic::Magnitude(op) => {
189 reject_complex_operation(scope, [op.input.ty], "magnitude")
190 }
191 Arithmetic::Normalize(op) => {
192 reject_complex_operation(scope, [op.input.ty], "normalize")
193 }
194 Arithmetic::Dot(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "dot"),
195 Arithmetic::MulHi(op) => {
196 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "mul_hi")
197 }
198 Arithmetic::VectorSum(op) => {
199 reject_complex_operation(scope, [op.input.ty], "vector_sum")
200 }
201 },
202 Operation::Comparison(comparison) => match comparison {
203 Comparison::Equal(op) => {
204 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Compare, "==")
205 }
206 Comparison::NotEqual(op) => {
207 require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Compare, "!=")
208 }
209 Comparison::Lower(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "<"),
210 Comparison::LowerEqual(op) => {
211 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "<=")
212 }
213 Comparison::GreaterEqual(op) => {
214 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], ">=")
215 }
216 Comparison::Greater(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], ">"),
217 Comparison::IsNan(op) => reject_complex_operation(scope, [op.input.ty], "is_nan"),
218 Comparison::IsInf(op) => reject_complex_operation(scope, [op.input.ty], "is_inf"),
219 },
220 Operation::Bitwise(bitwise) => match bitwise {
221 Bitwise::BitwiseAnd(op)
222 | Bitwise::BitwiseOr(op)
223 | Bitwise::BitwiseXor(op)
224 | Bitwise::ShiftLeft(op)
225 | Bitwise::ShiftRight(op) => {
226 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "bitwise")
227 }
228 Bitwise::CountOnes(op)
229 | Bitwise::ReverseBits(op)
230 | Bitwise::BitwiseNot(op)
231 | Bitwise::LeadingZeros(op)
232 | Bitwise::TrailingZeros(op)
233 | Bitwise::FindFirstSet(op) => {
234 reject_complex_operation(scope, [op.input.ty], "bitwise")
235 }
236 },
237 Operation::Operator(operator) => match operator {
238 Operator::Real(op) => {
239 require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "real_val")
240 }
241 Operator::Imag(op) => {
242 require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "imag_val")
243 }
244 Operator::And(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "&&"),
245 Operator::Or(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "||"),
246 Operator::Not(op) => reject_complex_operation(scope, [op.input.ty], "!"),
247 _ => {}
248 },
249 _ => {}
250 }
251}
252
253pub(crate) fn validate_complex_assign_operation(scope: &mut Scope, operation: &Operation) {
254 match operation {
255 Operation::Arithmetic(Arithmetic::Add(op)) => {
256 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "+=")
257 }
258 Operation::Arithmetic(Arithmetic::Sub(op)) => {
259 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "-=")
260 }
261 Operation::Arithmetic(Arithmetic::Mul(op)) => {
262 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "*=")
263 }
264 Operation::Arithmetic(Arithmetic::Div(op)) => {
265 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "/=")
266 }
267 Operation::Arithmetic(Arithmetic::Modulo(op)) => {
268 reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "%=")
269 }
270 _ => {}
271 }
272}