Skip to main content

cubecl_core/frontend/
validation.rs

1use crate as cubecl;
2use alloc::{
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use cubecl::prelude::*;
8use cubecl_ir::{
9    Arithmetic, Bitwise, Comparison, Operation, Operator, Scope, StorageType, Type,
10    features::ComplexUsage,
11};
12use cubecl_macros::intrinsic;
13
#[cube]
// NOTE(review): presumably needed because `msg` is only consumed inside
// `intrinsic!`, which may hide the use from the lint in generated code; confirm.
#[allow(unused_variables)]
/// Pushes a validation error onto the current scope, which makes kernel
/// compilation fail.
///
/// # Arguments
///
/// * `msg` - Compile-time (`#[comptime]`) message forwarded to `Scope::push_error`.
///
/// # Notes
///
/// The error can be caught after the kernel is launched.
pub fn push_validation_error(#[comptime] msg: String) {
    intrinsic! {|scope| scope.push_error(msg)}
}
24
25fn collect_complex_storage_types(types: impl IntoIterator<Item = Type>) -> Vec<StorageType> {
26    let mut storages = Vec::new();
27
28    for ty in types {
29        if ty.is_semantic() {
30            continue;
31        }
32
33        let storage = ty.storage_type();
34        if storage.elem_type().is_complex() && !storages.contains(&storage) {
35            storages.push(storage);
36        }
37    }
38
39    storages
40}
41
42fn require_complex_usage(
43    scope: &mut Scope,
44    types: impl IntoIterator<Item = Type>,
45    usage: ComplexUsage,
46    op_name: &'static str,
47) {
48    let storages = collect_complex_storage_types(types);
49    if storages.is_empty() {
50        return;
51    }
52
53    // `scope.properties` is populated for all production paths (kernel launch
54    // and compute-builder scopes). When it is unset — only possible from
55    // hand-rolled compiler tests that construct a bare `Scope` — we cannot
56    // decide capability and fall through silently. Such tests must set
57    // `device_properties` if they exercise complex types.
58    let Some(properties) = scope.properties.clone() else {
59        return;
60    };
61
62    for storage in storages {
63        if !properties.supports_complex_usage(storage, usage) {
64            scope.push_error(format!(
65                "Complex operation `{op_name}` requires {usage:?} support for `{storage}`, but the active runtime does not advertise it."
66            ));
67        }
68    }
69}
70
71fn reject_complex_operation(
72    scope: &mut Scope,
73    types: impl IntoIterator<Item = Type>,
74    op_name: &'static str,
75) {
76    let storages = collect_complex_storage_types(types);
77    if storages.is_empty() {
78        return;
79    }
80
81    let supported = storages
82        .into_iter()
83        .map(|storage| storage.to_string())
84        .collect::<Vec<_>>()
85        .join(", ");
86    scope.push_error(format!(
87        "Complex operation `{op_name}` is not part of the v1 complex contract for `{supported}`."
88    ));
89}
90
/// Validates `operation` against the complex-number contract, pushing any
/// violations onto `scope` as validation errors.
///
/// Operations that belong to the contract are gated on a runtime capability
/// via `require_complex_usage`:
/// * `ComplexUsage::Core`: `+`, `-`, `*`, `/`, `neg`, `conj`, `real_val`, `imag_val`.
/// * `ComplexUsage::Math`: `abs`, `exp`, `log`, `sin`, `cos`, `sqrt`, `tanh`, `powf`.
/// * `ComplexUsage::Compare`: `==`, `!=`.
///
/// Every other arithmetic, comparison, bitwise and logical operation listed
/// below is rejected via `reject_complex_operation` when an operand is complex.
/// Both helpers do nothing when no operand has a complex storage type.
pub(crate) fn validate_complex_operation(scope: &mut Scope, operation: &Operation) {
    match operation {
        Operation::Arithmetic(arithmetic) => match arithmetic {
            // Core arithmetic: requires `ComplexUsage::Core`.
            Arithmetic::Add(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "+")
            }
            Arithmetic::Sub(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "-")
            }
            Arithmetic::Mul(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "*")
            }
            Arithmetic::Div(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Core, "/")
            }
            Arithmetic::Neg(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "neg")
            }
            Arithmetic::Conj(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "conj")
            }
            // Math functions: requires `ComplexUsage::Math`.
            Arithmetic::Abs(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "abs")
            }
            Arithmetic::Exp(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "exp")
            }
            Arithmetic::Log(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "log")
            }
            Arithmetic::Sin(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "sin")
            }
            Arithmetic::Cos(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "cos")
            }
            Arithmetic::Sqrt(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "sqrt")
            }
            Arithmetic::Tanh(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Math, "tanh")
            }
            Arithmetic::Powf(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Math, "powf")
            }
            // Everything below is outside the complex contract and is
            // rejected whenever an operand is complex.
            Arithmetic::Fma(op) => {
                reject_complex_operation(scope, [op.a.ty, op.b.ty, op.c.ty], "fma")
            }
            Arithmetic::SaturatingAdd(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "saturating_add")
            }
            Arithmetic::SaturatingSub(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "saturating_sub")
            }
            Arithmetic::Log1p(op) => reject_complex_operation(scope, [op.input.ty], "log1p"),
            Arithmetic::Expm1(op) => reject_complex_operation(scope, [op.input.ty], "exp_m1"),
            Arithmetic::Tan(op) => reject_complex_operation(scope, [op.input.ty], "tan"),
            Arithmetic::Sinh(op) => reject_complex_operation(scope, [op.input.ty], "sinh"),
            Arithmetic::Cosh(op) => reject_complex_operation(scope, [op.input.ty], "cosh"),
            Arithmetic::ArcCos(op) => reject_complex_operation(scope, [op.input.ty], "acos"),
            Arithmetic::ArcSin(op) => reject_complex_operation(scope, [op.input.ty], "asin"),
            Arithmetic::ArcTan(op) => reject_complex_operation(scope, [op.input.ty], "atan"),
            Arithmetic::ArcSinh(op) => reject_complex_operation(scope, [op.input.ty], "asinh"),
            Arithmetic::ArcCosh(op) => reject_complex_operation(scope, [op.input.ty], "acosh"),
            Arithmetic::ArcTanh(op) => reject_complex_operation(scope, [op.input.ty], "atanh"),
            Arithmetic::Degrees(op) => reject_complex_operation(scope, [op.input.ty], "to_degrees"),
            Arithmetic::Radians(op) => reject_complex_operation(scope, [op.input.ty], "to_radians"),
            Arithmetic::ArcTan2(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "atan2")
            }
            Arithmetic::Powi(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "powi"),
            Arithmetic::Hypot(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "hypot")
            }
            Arithmetic::Rhypot(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "rhypot")
            }
            Arithmetic::InverseSqrt(op) => {
                reject_complex_operation(scope, [op.input.ty], "inverse_sqrt")
            }
            Arithmetic::Round(op) => reject_complex_operation(scope, [op.input.ty], "round"),
            Arithmetic::Floor(op) => reject_complex_operation(scope, [op.input.ty], "floor"),
            Arithmetic::Ceil(op) => reject_complex_operation(scope, [op.input.ty], "ceil"),
            Arithmetic::Trunc(op) => reject_complex_operation(scope, [op.input.ty], "trunc"),
            Arithmetic::Erf(op) => reject_complex_operation(scope, [op.input.ty], "erf"),
            Arithmetic::Recip(op) => reject_complex_operation(scope, [op.input.ty], "recip"),
            Arithmetic::Clamp(op) => reject_complex_operation(
                scope,
                [op.input.ty, op.min_value.ty, op.max_value.ty],
                "clamp",
            ),
            Arithmetic::Modulo(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "%"),
            Arithmetic::Max(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "max"),
            Arithmetic::Min(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "min"),
            Arithmetic::Remainder(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "remainder")
            }
            Arithmetic::Magnitude(op) => {
                reject_complex_operation(scope, [op.input.ty], "magnitude")
            }
            Arithmetic::Normalize(op) => {
                reject_complex_operation(scope, [op.input.ty], "normalize")
            }
            Arithmetic::Dot(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "dot"),
            Arithmetic::MulHi(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "mul_hi")
            }
            Arithmetic::VectorSum(op) => {
                reject_complex_operation(scope, [op.input.ty], "vector_sum")
            }
        },
        Operation::Comparison(comparison) => match comparison {
            // Only equality is in the contract (requires `ComplexUsage::Compare`);
            // ordering and float classification are rejected.
            Comparison::Equal(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Compare, "==")
            }
            Comparison::NotEqual(op) => {
                require_complex_usage(scope, [op.lhs.ty, op.rhs.ty], ComplexUsage::Compare, "!=")
            }
            Comparison::Lower(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "<"),
            Comparison::LowerEqual(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "<=")
            }
            Comparison::GreaterEqual(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], ">=")
            }
            Comparison::Greater(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], ">"),
            Comparison::IsNan(op) => reject_complex_operation(scope, [op.input.ty], "is_nan"),
            Comparison::IsInf(op) => reject_complex_operation(scope, [op.input.ty], "is_inf"),
        },
        // Every bitwise variant is rejected and reported under the shared name `bitwise`.
        Operation::Bitwise(bitwise) => match bitwise {
            Bitwise::BitwiseAnd(op)
            | Bitwise::BitwiseOr(op)
            | Bitwise::BitwiseXor(op)
            | Bitwise::ShiftLeft(op)
            | Bitwise::ShiftRight(op) => {
                reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "bitwise")
            }
            Bitwise::CountOnes(op)
            | Bitwise::ReverseBits(op)
            | Bitwise::BitwiseNot(op)
            | Bitwise::LeadingZeros(op)
            | Bitwise::TrailingZeros(op)
            | Bitwise::FindFirstSet(op) => {
                reject_complex_operation(scope, [op.input.ty], "bitwise")
            }
        },
        Operation::Operator(operator) => match operator {
            // Component extraction is part of the core contract.
            Operator::Real(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "real_val")
            }
            Operator::Imag(op) => {
                require_complex_usage(scope, [op.input.ty], ComplexUsage::Core, "imag_val")
            }
            // Logical operators are rejected for complex operands.
            Operator::And(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "&&"),
            Operator::Or(op) => reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "||"),
            Operator::Not(op) => reject_complex_operation(scope, [op.input.ty], "!"),
            // Remaining operators are not checked by this validator.
            _ => {}
        },
        // Remaining operation kinds are not checked by this validator.
        _ => {}
    }
}
252
253pub(crate) fn validate_complex_assign_operation(scope: &mut Scope, operation: &Operation) {
254    match operation {
255        Operation::Arithmetic(Arithmetic::Add(op)) => {
256            reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "+=")
257        }
258        Operation::Arithmetic(Arithmetic::Sub(op)) => {
259            reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "-=")
260        }
261        Operation::Arithmetic(Arithmetic::Mul(op)) => {
262            reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "*=")
263        }
264        Operation::Arithmetic(Arithmetic::Div(op)) => {
265            reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "/=")
266        }
267        Operation::Arithmetic(Arithmetic::Modulo(op)) => {
268            reject_complex_operation(scope, [op.lhs.ty, op.rhs.ty], "%=")
269        }
270        _ => {}
271    }
272}