1use std::sync::{Arc, RwLock, RwLockReadGuard};
25
26use derive_more::{From, IsVariant, Unwrap};
27
28use crate::span::Spanned;
29
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
34#[derive(Default, Clone, Debug, PartialEq)]
35pub struct TranslationUnit {
36 #[cfg(feature = "imports")]
37 pub imports: Vec<ImportStatement>,
38 pub global_directives: Vec<GlobalDirective>,
39 pub global_declarations: Vec<GlobalDeclaration>,
40}
41
/// A shared, renameable identifier.
///
/// Cloning an `Ident` does not copy the name: every clone points at the same
/// `Arc<RwLock<String>>`, so a [`rename`](Ident::rename) through any clone is seen by
/// all of them. Equality and hashing are by *identity* (the shared allocation), not by
/// text: two independently created `Ident`s holding the same name are not equal.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Ident(Arc<RwLock<String>>);

impl Ident {
    /// Creates a fresh identifier backed by its own, unshared storage.
    pub fn new(name: String) -> Ident {
        let cell = RwLock::new(name);
        Self(Arc::new(cell))
    }

    /// Borrows the current name.
    ///
    /// The returned guard holds a read lock on the shared storage. Drop it before
    /// calling [`rename`](Ident::rename) on this identifier (or any clone of it) from
    /// the same thread; otherwise the write lock may block or panic, since `std`
    /// leaves re-entrant locking unspecified.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn name(&self) -> RwLockReadGuard<'_, String> {
        let shared = &self.0;
        shared.read().unwrap()
    }

    /// Replaces the name for this identifier and every clone sharing its storage.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn rename(&mut self, name: String) {
        let mut slot = self.0.write().unwrap();
        *slot = name;
    }

    /// Strong reference count of the shared storage, i.e. one per live clone of this
    /// identifier, this one included. The inner `Arc` is private, so presumably only
    /// `Ident`s hold it.
    pub fn use_count(&self) -> usize {
        Arc::strong_count(&self.0)
    }
}

impl PartialEq for Ident {
    /// Identity comparison: `true` only when both handles share one allocation.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        Arc::ptr_eq(lhs, rhs)
    }
}

impl Eq for Ident {}

impl std::hash::Hash for Ident {
    /// Hashes the address of the shared storage, so hashing agrees with the
    /// identity-based `eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let addr: *const RwLock<String> = Arc::as_ptr(&self.0);
        std::hash::Hash::hash(&addr, state)
    }
}
87
/// An import statement (`imports` feature): a module path plus what is imported from it.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct ImportStatement {
    /// Attributes attached to the statement (`attributes` feature only).
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// Module the import starts from.
    pub path: ModulePath,
    /// Either a single item or a nested collection of imports.
    pub content: ImportContent,
}

/// Anchor of a [`ModulePath`].
///
/// `Relative` carries a count, presumably the number of parent-module steps taken
/// before the path components apply (TODO confirm against the parser).
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, IsVariant)]
pub enum PathOrigin {
    Absolute,
    Relative(usize),
    Package,
}

/// A module path: an anchor plus its components, stored as plain strings (not `Ident`s).
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModulePath {
    pub origin: PathOrigin,
    pub components: Vec<String>,
}

/// One entry of an [`ImportContent::Collection`]: path components and what they lead to.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct Import {
    pub path: Vec<String>,
    pub content: ImportContent,
}

/// What an import brings in: a single item, or a nested list of further imports.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, IsVariant)]
pub enum ImportContent {
    Item(ImportItem),
    Collection(Vec<Import>),
}

/// A single imported item, with an optional local rename.
#[cfg(feature = "imports")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct ImportItem {
    /// The imported name.
    pub ident: Ident,
    /// Name it is bound to locally, if renamed.
    pub rename: Option<Ident>,
}
138
139#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
140#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
141pub enum GlobalDirective {
142 Diagnostic(DiagnosticDirective),
143 Enable(EnableDirective),
144 Requires(RequiresDirective),
145}
146
147#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
148#[derive(Clone, Debug, PartialEq)]
149pub struct DiagnosticDirective {
150 #[cfg(feature = "attributes")]
151 pub attributes: Attributes,
152 pub severity: DiagnosticSeverity,
153 pub rule_name: String,
154}
155
156#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
157#[derive(Clone, Debug, PartialEq, Eq, IsVariant)]
158pub enum DiagnosticSeverity {
159 Error,
160 Warning,
161 Info,
162 Off,
163}
164
/// An `enable` directive and the extension names it lists.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct EnableDirective {
    /// Attributes on the directive (`attributes` feature only).
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// Names listed by the directive.
    pub extensions: Vec<String>,
}

/// A `requires` directive and the names it lists.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct RequiresDirective {
    /// Attributes on the directive (`attributes` feature only).
    #[cfg(feature = "attributes")]
    pub attributes: Attributes,
    /// Names listed by the directive.
    pub extensions: Vec<String>,
}
180
181#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
182#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
183pub enum GlobalDeclaration {
184 Void,
185 Declaration(Declaration),
186 TypeAlias(TypeAlias),
187 Struct(Struct),
188 Function(Function),
189 ConstAssert(ConstAssert),
190}
191
192#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
193#[derive(Clone, Debug, PartialEq)]
194pub struct Declaration {
195 pub attributes: Attributes,
196 pub kind: DeclarationKind,
197 pub ident: Ident,
198 pub ty: Option<TypeExpression>,
199 pub initializer: Option<ExpressionNode>,
200}
201
202#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
203#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
204pub enum DeclarationKind {
205 Const,
206 Override,
207 Let,
208 Var(Option<AddressSpace>), }
210
211#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
212#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
213pub enum AddressSpace {
214 Function,
215 Private,
216 Workgroup,
217 Uniform,
218 Storage(Option<AccessMode>),
219 Handle, }
221
222#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
223#[derive(Clone, Copy, Debug, PartialEq, Eq)]
224pub enum AccessMode {
225 Read,
226 Write,
227 ReadWrite,
228}
229
230#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
231#[derive(Clone, Debug, PartialEq)]
232pub struct TypeAlias {
233 #[cfg(feature = "attributes")]
234 pub attributes: Attributes,
235 pub ident: Ident,
236 pub ty: TypeExpression,
237}
238
239#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
240#[derive(Clone, Debug, PartialEq)]
241pub struct Struct {
242 #[cfg(feature = "attributes")]
243 pub attributes: Attributes,
244 pub ident: Ident,
245 pub members: Vec<StructMember>,
246}
247
248#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
249#[derive(Clone, Debug, PartialEq)]
250pub struct StructMember {
251 pub attributes: Attributes,
252 pub ident: Ident,
253 pub ty: TypeExpression,
254}
255
256#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
257#[derive(Clone, Debug, PartialEq)]
258pub struct Function {
259 pub attributes: Attributes,
260 pub ident: Ident,
261 pub parameters: Vec<FormalParameter>,
262 pub return_attributes: Attributes,
263 pub return_type: Option<TypeExpression>,
264 pub body: CompoundStatement,
265}
266
267#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
268#[derive(Clone, Debug, PartialEq)]
269pub struct FormalParameter {
270 pub attributes: Attributes,
271 pub ident: Ident,
272 pub ty: TypeExpression,
273}
274
275#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
276#[derive(Clone, Debug, PartialEq)]
277pub struct ConstAssert {
278 #[cfg(feature = "attributes")]
279 pub attributes: Attributes,
280 pub expression: ExpressionNode,
281}
282
283#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
284#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
285pub enum BuiltinValue {
286 VertexIndex,
287 InstanceIndex,
288 Position,
289 FrontFacing,
290 FragDepth,
291 SampleIndex,
292 SampleMask,
293 LocalInvocationId,
294 LocalInvocationIndex,
295 GlobalInvocationId,
296 WorkgroupId,
297 NumWorkgroups,
298}
299
300#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
301#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
302pub enum InterpolationType {
303 Perspective,
304 Linear,
305 Flat,
306}
307
308#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
309#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
310pub enum InterpolationSampling {
311 Center,
312 Centroid,
313 Sample,
314 First,
315 Either,
316}
317
318#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
319#[derive(Clone, Debug, PartialEq)]
320pub struct DiagnosticAttribute {
321 pub severity: DiagnosticSeverity,
322 pub rule: String,
323}
324
325#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
326#[derive(Clone, Debug, PartialEq)]
327pub struct InterpolateAttribute {
328 pub ty: InterpolationType,
329 pub sampling: Option<InterpolationSampling>,
330}
331
332#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
333#[derive(Clone, Debug, PartialEq)]
334pub struct WorkgroupSizeAttribute {
335 pub x: ExpressionNode,
336 pub y: Option<ExpressionNode>,
337 pub z: Option<ExpressionNode>,
338}
339
340#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
341#[derive(Clone, Debug, PartialEq)]
342pub struct CustomAttribute {
343 pub name: String,
344 pub arguments: Option<Vec<ExpressionNode>>,
345}
346
347#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
348#[derive(Clone, Debug, PartialEq, IsVariant, Unwrap)]
349pub enum Attribute {
350 Align(ExpressionNode),
351 Binding(ExpressionNode),
352 BlendSrc(ExpressionNode),
353 Builtin(BuiltinValue),
354 Const,
355 Diagnostic(DiagnosticAttribute),
356 Group(ExpressionNode),
357 Id(ExpressionNode),
358 Interpolate(InterpolateAttribute),
359 Invariant,
360 Location(ExpressionNode),
361 MustUse,
362 Size(ExpressionNode),
363 WorkgroupSize(WorkgroupSizeAttribute),
364 Vertex,
365 Fragment,
366 Compute,
367 #[cfg(feature = "condcomp")]
368 If(ExpressionNode),
369 #[cfg(feature = "condcomp")]
370 Elif(ExpressionNode),
371 #[cfg(feature = "condcomp")]
372 Else,
373 #[cfg(feature = "generics")]
374 Type(TypeConstraint),
375 Custom(CustomAttribute),
376}
377
378#[cfg(feature = "generics")]
379#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
380#[derive(Clone, Debug, PartialEq, From)]
381pub struct TypeConstraint {
382 pub ident: Ident,
383 pub variants: Vec<TypeExpression>,
384}
385
386pub type Attributes = Vec<Attribute>;
387
388#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
389#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
390pub enum Expression {
391 Literal(LiteralExpression),
392 Parenthesized(ParenthesizedExpression),
393 NamedComponent(NamedComponentExpression),
394 Indexing(IndexingExpression),
395 Unary(UnaryExpression),
396 Binary(BinaryExpression),
397 FunctionCall(FunctionCallExpression),
398 TypeOrIdentifier(TypeExpression),
399}
400
401pub type ExpressionNode = Spanned<Expression>;
402
403#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
404#[derive(Clone, Copy, Debug, PartialEq, From, IsVariant, Unwrap)]
405pub enum LiteralExpression {
406 Bool(bool),
407 AbstractInt(i64),
408 AbstractFloat(f64),
409 I32(i32),
410 U32(u32),
411 F32(f32),
412 #[from(skip)]
413 F16(f32),
414}
415
416#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
417#[derive(Clone, Debug, PartialEq)]
418pub struct ParenthesizedExpression {
419 pub expression: ExpressionNode,
420}
421
422#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
423#[derive(Clone, Debug, PartialEq)]
424pub struct NamedComponentExpression {
425 pub base: ExpressionNode,
426 pub component: Ident,
427}
428
429#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
430#[derive(Clone, Debug, PartialEq)]
431pub struct IndexingExpression {
432 pub base: ExpressionNode,
433 pub index: ExpressionNode,
434}
435
436#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
437#[derive(Clone, Debug, PartialEq)]
438pub struct UnaryExpression {
439 pub operator: UnaryOperator,
440 pub operand: ExpressionNode,
441}
442
443#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
444#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
445pub enum UnaryOperator {
446 LogicalNegation,
447 Negation,
448 BitwiseComplement,
449 AddressOf,
450 Indirection,
451}
452
453#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
454#[derive(Clone, Debug, PartialEq)]
455pub struct BinaryExpression {
456 pub operator: BinaryOperator,
457 pub left: ExpressionNode,
458 pub right: ExpressionNode,
459}
460
461#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
462#[derive(Clone, Copy, Debug, PartialEq, Eq, IsVariant)]
463pub enum BinaryOperator {
464 ShortCircuitOr,
465 ShortCircuitAnd,
466 Addition,
467 Subtraction,
468 Multiplication,
469 Division,
470 Remainder,
471 Equality,
472 Inequality,
473 LessThan,
474 LessThanEqual,
475 GreaterThan,
476 GreaterThanEqual,
477 BitwiseOr,
478 BitwiseAnd,
479 BitwiseXor,
480 ShiftLeft,
481 ShiftRight,
482}
483
484#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
485#[derive(Clone, Debug, PartialEq)]
486pub struct FunctionCall {
487 pub ty: TypeExpression,
488 pub arguments: Vec<ExpressionNode>,
489}
490
491pub type FunctionCallExpression = FunctionCall;
492
493#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
494#[derive(Clone, Debug, PartialEq)]
495pub struct TypeExpression {
496 #[cfg(feature = "imports")]
497 pub path: Option<ModulePath>,
498 pub ident: Ident,
499 pub template_args: TemplateArgs,
500}
501
502#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
503#[derive(Clone, Debug, PartialEq)]
504pub struct TemplateArg {
505 pub expression: ExpressionNode,
506}
507pub type TemplateArgs = Option<Vec<TemplateArg>>;
508
509#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
510#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
511pub enum Statement {
512 Void,
513 Compound(CompoundStatement),
514 Assignment(AssignmentStatement),
515 Increment(IncrementStatement),
516 Decrement(DecrementStatement),
517 If(IfStatement),
518 Switch(SwitchStatement),
519 Loop(LoopStatement),
520 For(ForStatement),
521 While(WhileStatement),
522 Break(BreakStatement),
523 Continue(ContinueStatement),
524 Return(ReturnStatement),
525 Discard(DiscardStatement),
526 FunctionCall(FunctionCallStatement),
527 ConstAssert(ConstAssertStatement),
528 Declaration(DeclarationStatement),
529}
530
531pub type StatementNode = Spanned<Statement>;
532
533#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
534#[derive(Clone, Debug, PartialEq)]
535pub struct CompoundStatement {
536 pub attributes: Attributes,
537 pub statements: Vec<StatementNode>,
538}
539
540#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
541#[derive(Clone, Debug, PartialEq)]
542pub struct AssignmentStatement {
543 #[cfg(feature = "attributes")]
544 pub attributes: Attributes,
545 pub operator: AssignmentOperator,
546 pub lhs: ExpressionNode,
547 pub rhs: ExpressionNode,
548}
549
550#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
551#[derive(Clone, Debug, PartialEq, Eq, IsVariant)]
552pub enum AssignmentOperator {
553 Equal,
554 PlusEqual,
555 MinusEqual,
556 TimesEqual,
557 DivisionEqual,
558 ModuloEqual,
559 AndEqual,
560 OrEqual,
561 XorEqual,
562 ShiftRightAssign,
563 ShiftLeftAssign,
564}
565
566#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
567#[derive(Clone, Debug, PartialEq)]
568pub struct IncrementStatement {
569 #[cfg(feature = "attributes")]
570 pub attributes: Attributes,
571 pub expression: ExpressionNode,
572}
573
574#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
575#[derive(Clone, Debug, PartialEq)]
576pub struct DecrementStatement {
577 #[cfg(feature = "attributes")]
578 pub attributes: Attributes,
579 pub expression: ExpressionNode,
580}
581
582#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
583#[derive(Clone, Debug, PartialEq)]
584pub struct IfStatement {
585 pub attributes: Attributes,
586 pub if_clause: IfClause,
587 pub else_if_clauses: Vec<ElseIfClause>,
588 pub else_clause: Option<ElseClause>,
589}
590
591#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
592#[derive(Clone, Debug, PartialEq)]
593pub struct IfClause {
594 pub expression: ExpressionNode,
595 pub body: CompoundStatement,
596}
597
598#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
599#[derive(Clone, Debug, PartialEq)]
600pub struct ElseIfClause {
601 #[cfg(feature = "attributes")]
602 pub attributes: Attributes,
603 pub expression: ExpressionNode,
604 pub body: CompoundStatement,
605}
606
607#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
608#[derive(Clone, Debug, PartialEq)]
609pub struct ElseClause {
610 #[cfg(feature = "attributes")]
611 pub attributes: Attributes,
612 pub body: CompoundStatement,
613}
614
615#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
616#[derive(Clone, Debug, PartialEq)]
617pub struct SwitchStatement {
618 pub attributes: Attributes,
619 pub expression: ExpressionNode,
620 pub body_attributes: Attributes,
621 pub clauses: Vec<SwitchClause>,
622}
623
624#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
625#[derive(Clone, Debug, PartialEq)]
626pub struct SwitchClause {
627 #[cfg(feature = "attributes")]
628 pub attributes: Attributes,
629 pub case_selectors: Vec<CaseSelector>,
630 pub body: CompoundStatement,
631}
632
633#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
634#[derive(Clone, Debug, PartialEq, From, IsVariant, Unwrap)]
635pub enum CaseSelector {
636 Default,
637 Expression(ExpressionNode),
638}
639
640#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
641#[derive(Clone, Debug, PartialEq)]
642pub struct LoopStatement {
643 pub attributes: Attributes,
644 pub body: CompoundStatement,
645 pub continuing: Option<ContinuingStatement>,
649}
650
651#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
652#[derive(Clone, Debug, PartialEq)]
653pub struct ContinuingStatement {
654 #[cfg(feature = "attributes")]
655 pub attributes: Attributes,
656 pub body: CompoundStatement,
657 pub break_if: Option<BreakIfStatement>,
661}
662
663#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
664#[derive(Clone, Debug, PartialEq)]
665pub struct BreakIfStatement {
666 #[cfg(feature = "attributes")]
667 pub attributes: Attributes,
668 pub expression: ExpressionNode,
669}
670
671#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
672#[derive(Clone, Debug, PartialEq)]
673pub struct ForStatement {
674 pub attributes: Attributes,
675 pub initializer: Option<StatementNode>,
676 pub condition: Option<ExpressionNode>,
677 pub update: Option<StatementNode>,
678 pub body: CompoundStatement,
679}
680
681#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
682#[derive(Clone, Debug, PartialEq)]
683pub struct WhileStatement {
684 pub attributes: Attributes,
685 pub condition: ExpressionNode,
686 pub body: CompoundStatement,
687}
688
689#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
690#[derive(Clone, Debug, PartialEq)]
691pub struct BreakStatement {
692 #[cfg(feature = "attributes")]
693 pub attributes: Attributes,
694}
695
696#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
697#[derive(Clone, Debug, PartialEq)]
698pub struct ContinueStatement {
699 #[cfg(feature = "attributes")]
700 pub attributes: Attributes,
701}
702
703#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
704#[derive(Clone, Debug, PartialEq)]
705pub struct ReturnStatement {
706 #[cfg(feature = "attributes")]
707 pub attributes: Attributes,
708 pub expression: Option<ExpressionNode>,
709}
710
711#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
712#[derive(Clone, Debug, PartialEq)]
713pub struct DiscardStatement {
714 #[cfg(feature = "attributes")]
715 pub attributes: Attributes,
716}
717
718#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
719#[derive(Clone, Debug, PartialEq)]
720pub struct FunctionCallStatement {
721 #[cfg(feature = "attributes")]
722 pub attributes: Attributes,
723 pub call: FunctionCall,
724}
725
726pub type ConstAssertStatement = ConstAssert;
727
728pub type DeclarationStatement = Declaration;