cubecl_ir/features.rs
1use crate::{AddressType, SemanticType, StorageType, Type};
2use alloc::collections::{BTreeMap, BTreeSet};
3
4use enumset::EnumSetType;
5
6pub use enumset::EnumSet;
7
/// Features supported by a runtime.
///
/// Groups the boolean capability flags, the per-type support tables ([`Types`]) and the
/// matrix-multiplication configurations ([`MatmulFeatures`]) reported by a runtime.
/// The derived `Default` has every flag `false` and every set/map empty, i.e. no features.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations, such as cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the type of containers during kernel execution.
    pub memory_reinterpret: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Type support.
    pub types: Types,
    /// Matrix multiplication features.
    pub matmul: MatmulFeatures,

    /// Whether `copy_async` is supported.
    pub copy_async: bool,
    /// Tensor Memory Accelerator supported features.
    pub tma: EnumSet<Tma>,
    /// Whether vectors can be read from / stored to addresses not aligned
    /// with the `vector_size`.
    pub unaligned_io: bool,
}
33
/// Type support for a device.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Types {
    /// Valid address types.
    pub address: BTreeSet<AddressType>,
    /// Storage types supported by this runtime, and which usages each one supports.
    /// A type absent from this map is reported with an empty usage set by
    /// [`Features::type_usage`].
    pub storage: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Complex-specific capability families supported by this runtime, keyed by the complex
    /// storage type.
    pub complex: BTreeMap<StorageType, EnumSet<ComplexUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic: BTreeSet<SemanticType>,
    /// Supported types for atomic ops, keyed by the exact (possibly vectorized) `Type`.
    /// Only specific vectorizations of specific types are supported here, and not every
    /// vectorized type is also supported as a scalar (e.g. Vulkan on Nvidia only supports
    /// vectorized `f16`, not scalar). Only use the exact vectorizations registered here.
    /// These may not be supported everywhere - in practice, `f32` vectors are only supported in
    /// global memory.
    pub atomic: BTreeMap<Type, EnumSet<AtomicUsage>>,
}
52
/// Matrix multiplication-related features.
///
/// Each set lists the exact configurations/types a runtime accepts; an empty set means the
/// corresponding feature is unavailable.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct MatmulFeatures {
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported for ldmatrix, if any.
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by stmatrix, if any.
    pub stmatrix: BTreeSet<StorageType>,
}
69
/// Operations allowed for this type. CMMA is defined separately.
///
/// Variant declaration order defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
}
82
/// Complex capability families allowed for a complex storage type.
///
/// Queried per type through `Types::complex`.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum ComplexUsage {
    /// Core ML-centric complex functionality: arithmetic, negation, conjugation, real/imag.
    Core,
    /// Equality and inequality comparisons.
    Compare,
    /// Higher-level math functions such as exp/log/sin/cos/sqrt/tanh/powf and abs.
    Math,
}
93
94impl TypeUsage {
95 pub fn all() -> EnumSet<Self> {
96 EnumSet::all()
97 }
98
99 pub fn no_store() -> EnumSet<Self> {
100 TypeUsage::Conversion | TypeUsage::Arithmetic
101 }
102
103 pub fn maybe_store(storable: bool) -> EnumSet<Self> {
104 if storable {
105 EnumSet::all()
106 } else {
107 Self::no_store()
108 }
109 }
110}
111
/// Atomic operations allowed for this type.
///
/// Queried per (possibly vectorized) type through `Types::atomic`.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum AtomicUsage {
    /// Atomic loads and stores.
    LoadStore,
    /// Atomic add/sub.
    Add,
    /// Atomic min/max.
    MinMax,
}
122
123impl AtomicUsage {
124 pub fn all() -> EnumSet<Self> {
125 EnumSet::all()
126 }
127}
128
/// Supported plane features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
    /// Allows using plane operations with divergent control flow.
    NonUniformControlFlow,
}
139
/// Shape and element types of a valid MMA configuration.
///
/// Field declaration order defines the derived `Ord`, which determines the iteration order of
/// the `BTreeSet<MmaConfig>` sets in `MatmulFeatures`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}
157
/// Shape and element types of a valid block-scaled MMA configuration.
///
/// Field declaration order defines the derived `Ord`, which determines the iteration order of
/// `MatmulFeatures::scaled_mma`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// Element type of the block scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `vector_layout` of the operation.
    pub scales_factor: u32,
}
181
182/// Atomic features that may be supported by a ``Runtime``.
183#[derive(Debug, PartialOrd, Ord, EnumSetType)]
184pub enum Tma {
185 /// Base feature set for tensor memory accelerator features. Includes tiling and im2col
186 Base,
187 /// im2colWide encoding for tensor map.
188 Im2colWide,
189 /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
190 SwizzleAtomicity,
191}
192
193impl Features {
194 /// Get the usages for a type
195 pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
196 self.types
197 .storage
198 .get(&ty)
199 .cloned()
200 .unwrap_or_else(EnumSet::empty)
201 }
202
203 /// Get the complex capability families for a type.
204 pub fn complex_usage(&self, ty: StorageType) -> EnumSet<ComplexUsage> {
205 self.types
206 .complex
207 .get(&ty)
208 .cloned()
209 .unwrap_or_else(EnumSet::empty)
210 }
211
212 /// Get the usages for an atomic type
213 pub fn atomic_type_usage(&self, ty: Type) -> EnumSet<AtomicUsage> {
214 self.types
215 .atomic
216 .get(&ty)
217 .cloned()
218 .unwrap_or_else(EnumSet::empty)
219 }
220
221 /// Whether the type is supported in any way
222 pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
223 match ty.into() {
224 Type::Scalar(storage_type) | Type::Vector(storage_type, _) => {
225 self.types.storage.contains_key(&storage_type)
226 }
227 Type::Semantic(semantic_type) => self.types.semantic.contains(&semantic_type),
228 }
229 }
230
231 /// Whether the address type is supported in any way
232 pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
233 self.types.address.contains(&ty.into())
234 }
235
236 /// Whether a complex storage type supports the requested capability family.
237 pub fn supports_complex_usage(&self, ty: impl Into<StorageType>, usage: ComplexUsage) -> bool {
238 self.complex_usage(ty.into()).contains(usage)
239 }
240}