spirv_std_macros/
lib.rs

1// BEGIN - Embark standard lints v0.4
2// do not change or add/remove here, but one can add exceptions after this section
3// for more info see: <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
4#![deny(unsafe_code)]
5#![warn(
6    clippy::all,
7    clippy::await_holding_lock,
8    clippy::char_lit_as_u8,
9    clippy::checked_conversions,
10    clippy::dbg_macro,
11    clippy::debug_assert_with_mut_call,
12    clippy::doc_markdown,
13    clippy::empty_enum,
14    clippy::enum_glob_use,
15    clippy::exit,
16    clippy::expl_impl_clone_on_copy,
17    clippy::explicit_deref_methods,
18    clippy::explicit_into_iter_loop,
19    clippy::fallible_impl_from,
20    clippy::filter_map_next,
21    clippy::float_cmp_const,
22    clippy::fn_params_excessive_bools,
23    clippy::if_let_mutex,
24    clippy::implicit_clone,
25    clippy::imprecise_flops,
26    clippy::inefficient_to_string,
27    clippy::invalid_upcast_comparisons,
28    clippy::large_types_passed_by_value,
29    clippy::let_unit_value,
30    clippy::linkedlist,
31    clippy::lossy_float_literal,
32    clippy::macro_use_imports,
33    clippy::manual_ok_or,
34    clippy::map_err_ignore,
35    clippy::map_flatten,
36    clippy::map_unwrap_or,
37    clippy::match_on_vec_items,
38    clippy::match_same_arms,
39    clippy::match_wildcard_for_single_variants,
40    clippy::mem_forget,
41    clippy::mismatched_target_os,
42    clippy::mut_mut,
43    clippy::mutex_integer,
44    clippy::needless_borrow,
45    clippy::needless_continue,
46    clippy::option_option,
47    clippy::path_buf_push_overwrite,
48    clippy::ptr_as_ptr,
49    clippy::ref_option_ref,
50    clippy::rest_pat_in_fully_bound_structs,
51    clippy::same_functions_in_if_condition,
52    clippy::semicolon_if_nothing_returned,
53    clippy::string_add_assign,
54    clippy::string_add,
55    clippy::string_lit_as_bytes,
56    clippy::string_to_string,
57    clippy::todo,
58    clippy::trait_duplication_in_bounds,
59    clippy::unimplemented,
60    clippy::unnested_or_patterns,
61    clippy::unused_self,
62    clippy::useless_transmute,
63    clippy::verbose_file_reads,
64    clippy::zero_sized_map_values,
65    future_incompatible,
66    nonstandard_style,
67    rust_2018_idioms
68)]
69// END - Embark standard lints v0.4
70// crate-specific exceptions:
71// #![allow()]
72#![doc = include_str!("../README.md")]
73
74mod image;
75
76use proc_macro::TokenStream;
77use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
78
79use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, ItemFn, Token};
80
81use quote::{quote, ToTokens};
82use std::fmt::Write;
83
84/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
85/// `spirv_std::image::Image<...>` type.
86///
87/// The grammar for the macro is as follows:
88///
89/// ```rust,ignore
90/// Image!(
91///     <dimensionality>,
92///     <type=...|format=...>,
93///     [sampled[=<true|false>],]
94///     [multisampled[=<true|false>],]
95///     [arrayed[=<true|false>],]
96///     [depth[=<true|false>],]
97/// )
98/// ```
99///
100/// `=true` can be omitted as shorthand - e.g. `sampled` is short for `sampled=true`.
101///
102/// A basic example looks like this:
103/// ```rust,ignore
104/// #[spirv(vertex)]
105/// fn main(#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled)) {}
106/// ```
107///
108/// ## Arguments
109///
110/// - `dimensionality` — Dimensionality of an image.
111///    Accepted values: `1D`, `2D`, `3D`, `rect`, `cube`, `subpass`.
112/// - `type` — The sampled type of an image, mutually exclusive with `format`,
113///    when set the image format is unknown.
114///    Accepted values: `f32`, `f64`, `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`.
115/// - `format` — The image format of the image, mutually exclusive with `type`.
116///    Accepted values: Snake case versions of [`ImageFormat`].
117/// - `sampled` — Whether it is known that the image will be used with a sampler.
118///    Accepted values: `true` or `false`. Default: `unknown`.
119/// - `multisampled` — Whether the image contains multisampled content.
120///    Accepted values: `true` or `false`. Default: `false`.
121/// - `arrayed` — Whether the image contains arrayed content.
122///    Accepted values: `true` or `false`. Default: `false`.
123/// - `depth` — Whether it is known that the image is a depth image.
124///    Accepted values: `true` or `false`. Default: `unknown`.
125///
126/// [`ImageFormat`]: spirv_std_types::image_params::ImageFormat
127///
128/// Keep in mind that `sampled` here is a different concept than the `SampledImage` type:
129/// `sampled=true` means that this image requires a sampler to be able to access, while the
130/// `SampledImage` type bundles that sampler together with the image into a single type (e.g.
131/// `sampler2D` in GLSL, vs. `texture2D`).
132#[proc_macro]
133// The `Image` is supposed to be used in the type position, which
134// uses `PascalCase`.
135#[allow(nonstandard_style)]
136pub fn Image(item: TokenStream) -> TokenStream {
137    let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
138
139    output.into()
140}
141
142/// Replaces all (nested) occurrences of the `#[spirv(..)]` attribute with
143/// `#[cfg_attr(target_arch="spirv", rust_gpu::spirv(..))]`.
144#[proc_macro_attribute]
145pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
146    let mut tokens: Vec<TokenTree> = Vec::new();
147
148    // prepend with #[rust_gpu::spirv(..)]
149    let attr: proc_macro2::TokenStream = attr.into();
150    tokens.extend(quote! { #[cfg_attr(target_arch="spirv", rust_gpu::spirv(#attr))] });
151
152    let item: proc_macro2::TokenStream = item.into();
153    for tt in item {
154        match tt {
155            TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
156                let mut sub_tokens = Vec::new();
157                for tt in group.stream() {
158                    match tt {
159                        TokenTree::Group(group)
160                            if group.delimiter() == Delimiter::Bracket
161                                && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv")
162                                && matches!(sub_tokens.last(), Some(TokenTree::Punct(p)) if p.as_char() == '#') =>
163                        {
164                            // group matches [spirv ...]
165                            let inner = group.stream(); // group stream doesn't include the brackets
166                            sub_tokens.extend(
167                                quote! { [cfg_attr(target_arch="spirv", rust_gpu::#inner)] },
168                            );
169                        }
170                        _ => sub_tokens.push(tt),
171                    }
172                }
173                tokens.push(TokenTree::from(Group::new(
174                    Delimiter::Parenthesis,
175                    sub_tokens.into_iter().collect(),
176                )));
177            }
178            _ => tokens.push(tt),
179        }
180    }
181    tokens
182        .into_iter()
183        .collect::<proc_macro2::TokenStream>()
184        .into()
185}
186
187/// Marks a function as runnable only on the GPU, and will panic on
188/// CPU platforms.
189#[proc_macro_attribute]
190pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
191    let syn::ItemFn {
192        attrs,
193        vis,
194        sig,
195        block,
196    } = syn::parse_macro_input!(item as syn::ItemFn);
197
198    // FIXME(eddyb) this looks like a clippy false positive (`sig` is used below).
199    #[allow(clippy::redundant_clone)]
200    let fn_name = sig.ident.clone();
201
202    let sig_cpu = syn::Signature {
203        abi: None,
204        ..sig.clone()
205    };
206
207    let output = quote::quote! {
208        // Don't warn on unused arguments on the CPU side.
209        #[cfg(not(target_arch="spirv"))]
210        #[allow(unused_variables)]
211        #(#attrs)* #vis #sig_cpu {
212            unimplemented!(concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms."))
213        }
214
215        #[cfg(target_arch="spirv")]
216        #(#attrs)* #vis #sig {
217            #block
218        }
219    };
220
221    output.into()
222}
223
224/// Accepts a function with an argument named `component`, and outputs the
225/// function plus a vectorized version of the function which accepts a vector
226/// of `component`. This is mostly useful when you have the same impl body for
227/// a scalar and vector versions of the same operation.
228#[proc_macro_attribute]
229#[doc(hidden)]
230pub fn vectorized(_attr: TokenStream, item: TokenStream) -> TokenStream {
231    let function = syn::parse_macro_input!(item as syn::ItemFn);
232    let vectored_function = match create_vectored_fn(function.clone()) {
233        Ok(val) => val,
234        Err(err) => return err.to_compile_error().into(),
235    };
236
237    let output = quote::quote!(
238        #function
239
240        #vectored_function
241    );
242
243    output.into()
244}
245
246fn create_vectored_fn(
247    ItemFn {
248        attrs,
249        vis,
250        mut sig,
251        block,
252    }: ItemFn,
253) -> Result<ItemFn, syn::Error> {
254    const COMPONENT_ARG_NAME: &str = "component";
255    let trait_bound_name = Ident::new("VECTOR", Span::mixed_site());
256    let const_bound_name = Ident::new("LENGTH", Span::mixed_site());
257
258    sig.ident = Ident::new(&format!("{}_vector", sig.ident), Span::mixed_site());
259    sig.output = syn::ReturnType::Type(
260        Default::default(),
261        Box::new(path_from_ident(trait_bound_name.clone())),
262    );
263
264    let component_type = sig.inputs.iter_mut().find_map(|x| match x {
265        syn::FnArg::Typed(ty) => match &*ty.pat {
266            syn::Pat::Ident(pat) if pat.ident == COMPONENT_ARG_NAME => Some(&mut ty.ty),
267            _ => None,
268        },
269        syn::FnArg::Receiver(_) => None,
270    });
271
272    if component_type.is_none() {
273        return Err(syn::Error::new(
274            sig.inputs.span(),
275            "#[vectorized] requires an argument named `component`.",
276        ));
277    }
278    let component_type = component_type.unwrap();
279
280    let vector_path = {
281        let mut path = syn::Path {
282            leading_colon: None,
283            segments: Punctuated::new(),
284        };
285
286        for segment in &["crate", "vector"] {
287            path.segments
288                .push(Ident::new(segment, Span::mixed_site()).into());
289        }
290
291        path.segments.push(syn::PathSegment {
292            ident: Ident::new("Vector", Span::mixed_site()),
293            arguments: syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
294                colon2_token: None,
295                lt_token: Default::default(),
296                args: {
297                    let mut punct = Punctuated::new();
298
299                    punct.push(syn::GenericArgument::Type(*component_type.clone()));
300                    punct.push(syn::GenericArgument::Type(path_from_ident(
301                        const_bound_name.clone(),
302                    )));
303
304                    punct
305                },
306                gt_token: Default::default(),
307            }),
308        });
309
310        path
311    };
312
313    // Replace the original component type with vector version.
314    **component_type = path_from_ident(trait_bound_name.clone());
315
316    let trait_bounds = {
317        let mut punct = Punctuated::new();
318        punct.push(syn::TypeParamBound::Trait(syn::TraitBound {
319            paren_token: None,
320            modifier: syn::TraitBoundModifier::None,
321            lifetimes: None,
322            path: vector_path,
323        }));
324        punct
325    };
326
327    sig.generics
328        .params
329        .push(syn::GenericParam::Type(syn::TypeParam {
330            attrs: Vec::new(),
331            ident: trait_bound_name,
332            colon_token: Some(Token![:](Span::mixed_site())),
333            bounds: trait_bounds,
334            eq_token: None,
335            default: None,
336        }));
337
338    sig.generics
339        .params
340        .push(syn::GenericParam::Const(syn::ConstParam {
341            attrs: Vec::default(),
342            const_token: Default::default(),
343            ident: const_bound_name,
344            colon_token: Default::default(),
345            ty: syn::Type::Path(syn::TypePath {
346                qself: None,
347                path: Ident::new("usize", Span::mixed_site()).into(),
348            }),
349            eq_token: None,
350            default: None,
351        }));
352
353    Ok(ItemFn {
354        attrs,
355        vis,
356        sig,
357        block,
358    })
359}
360
361fn path_from_ident(ident: Ident) -> syn::Type {
362    syn::Type::Path(syn::TypePath {
363        qself: None,
364        path: syn::Path::from(ident),
365    })
366}
367
368/// Print a formatted string with a newline using the debug printf extension.
369///
370/// Examples:
371///
372/// ```rust,ignore
373/// debug_printfln!("uv: %v2f", uv);
374/// debug_printfln!("pos.x: %f, pos.z: %f, int: %i", pos.x, pos.z, int);
375/// ```
376///
377/// See <https://github.com/KhronosGroup/Vulkan-ValidationLayers/blob/main/docs/debug_printf.md#debug-printf-format-string> for formatting rules.
378#[proc_macro]
379pub fn debug_printf(input: TokenStream) -> TokenStream {
380    debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
381}
382
383/// Similar to `debug_printf` but appends a newline to the format string.
384#[proc_macro]
385pub fn debug_printfln(input: TokenStream) -> TokenStream {
386    let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
387    input.format_string.push('\n');
388    debug_printf_inner(input)
389}
390
391struct DebugPrintfInput {
392    span: proc_macro2::Span,
393    format_string: String,
394    variables: Vec<syn::Expr>,
395}
396
397impl syn::parse::Parse for DebugPrintfInput {
398    fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
399        let span = input.span();
400
401        if input.is_empty() {
402            return Ok(Self {
403                span,
404                format_string: Default::default(),
405                variables: Default::default(),
406            });
407        }
408
409        let format_string = input.parse::<syn::LitStr>()?;
410        if !input.is_empty() {
411            input.parse::<syn::token::Comma>()?;
412        }
413        let variables =
414            syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
415
416        Ok(Self {
417            span,
418            format_string: format_string.value(),
419            variables: variables.into_iter().collect(),
420        })
421    }
422}
423
424fn parsing_error(message: &str, span: proc_macro2::Span) -> TokenStream {
425    syn::Error::new(span, message).to_compile_error().into()
426}
427
428enum FormatType {
429    Scalar {
430        ty: proc_macro2::TokenStream,
431    },
432    Vector {
433        ty: proc_macro2::TokenStream,
434        width: usize,
435    },
436}
437
438fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
439    let DebugPrintfInput {
440        format_string,
441        variables,
442        span,
443    } = input;
444
445    fn map_specifier_to_type(
446        specifier: char,
447        chars: &mut std::str::Chars<'_>,
448    ) -> Option<proc_macro2::TokenStream> {
449        let mut peekable = chars.peekable();
450
451        Some(match specifier {
452            'd' | 'i' => quote::quote! { i32 },
453            'o' | 'x' | 'X' => quote::quote! { u32 },
454            'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
455            'u' => {
456                if matches!(peekable.peek(), Some('l')) {
457                    chars.next();
458                    quote::quote! { u64 }
459                } else {
460                    quote::quote! { u32 }
461                }
462            }
463            'l' => {
464                if matches!(peekable.peek(), Some('u' | 'x')) {
465                    chars.next();
466                    quote::quote! { u64 }
467                } else {
468                    return None;
469                }
470            }
471            _ => return None,
472        })
473    }
474
475    let mut chars = format_string.chars();
476    let mut format_arguments = Vec::new();
477
478    while let Some(mut ch) = chars.next() {
479        if ch == '%' {
480            ch = match chars.next() {
481                Some('%') => continue,
482                None => return parsing_error("Unterminated format specifier", span),
483                Some(ch) => ch,
484            };
485
486            let mut has_precision = false;
487
488            while ch.is_ascii_digit() {
489                ch = match chars.next() {
490                    Some(ch) => ch,
491                    None => {
492                        return parsing_error(
493                            "Unterminated format specifier: missing type after precision",
494                            span,
495                        );
496                    }
497                };
498
499                has_precision = true;
500            }
501
502            if has_precision && ch == '.' {
503                ch = match chars.next() {
504                    Some(ch) => ch,
505                    None => {
506                        return parsing_error(
507                            "Unterminated format specifier: missing type after decimal point",
508                            span,
509                        );
510                    }
511                };
512
513                while ch.is_ascii_digit() {
514                    ch = match chars.next() {
515                        Some(ch) => ch,
516                        None => {
517                            return parsing_error(
518                                "Unterminated format specifier: missing type after fraction precision",
519                                span,
520                            );
521                        }
522                    };
523                }
524            }
525
526            if ch == 'v' {
527                let width = match chars.next() {
528                    Some('2') => 2,
529                    Some('3') => 3,
530                    Some('4') => 4,
531                    Some(ch) => {
532                        return parsing_error(&format!("Invalid width for vector: {ch}"), span);
533                    }
534                    None => return parsing_error("Missing vector dimensions specifier", span),
535                };
536
537                ch = match chars.next() {
538                    Some(ch) => ch,
539                    None => return parsing_error("Missing vector type specifier", span),
540                };
541
542                let ty = match map_specifier_to_type(ch, &mut chars) {
543                    Some(ty) => ty,
544                    _ => {
545                        return parsing_error(
546                            &format!("Unrecognised vector type specifier: '{ch}'"),
547                            span,
548                        );
549                    }
550                };
551
552                format_arguments.push(FormatType::Vector { ty, width });
553            } else {
554                let ty = match map_specifier_to_type(ch, &mut chars) {
555                    Some(ty) => ty,
556                    _ => {
557                        return parsing_error(
558                            &format!("Unrecognised format specifier: '{ch}'"),
559                            span,
560                        );
561                    }
562                };
563
564                format_arguments.push(FormatType::Scalar { ty });
565            }
566        }
567    }
568
569    if format_arguments.len() != variables.len() {
570        return syn::Error::new(
571            span,
572            format!(
573                "{} % arguments were found, but {} variables were given",
574                format_arguments.len(),
575                variables.len()
576            ),
577        )
578        .to_compile_error()
579        .into();
580    }
581
582    let mut variable_idents = String::new();
583    let mut input_registers = Vec::new();
584    let mut op_loads = Vec::new();
585
586    for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
587    {
588        let ident = quote::format_ident!("_{}", i);
589
590        let _ = write!(variable_idents, "%{ident} ");
591
592        let assert_fn = match format_argument {
593            FormatType::Scalar { ty } => {
594                quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
595            }
596            FormatType::Vector { ty, width } => {
597                quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
598            }
599        };
600
601        input_registers.push(quote::quote! {
602            #ident = in(reg) &#assert_fn(#variable),
603        });
604
605        let op_load = format!("%{ident} = OpLoad _ {{{ident}}}");
606
607        op_loads.push(quote::quote! {
608            #op_load,
609        });
610    }
611
612    let input_registers = input_registers
613        .into_iter()
614        .collect::<proc_macro2::TokenStream>();
615    let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
616
617    let op_string = format!("%string = OpString {format_string:?}");
618
619    let output = quote::quote! {
620        ::core::arch::asm!(
621            "%void = OpTypeVoid",
622            #op_string,
623            "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
624            #op_loads
625            concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
626            #input_registers
627        )
628    };
629
630    output.into()
631}
632
/// Number of optional image-sampling parameters. Bit `i` of a permutation mask enables the
/// parameter described at index `i` of each `SAMPLE_PARAM_*` table below.
const SAMPLE_PARAM_COUNT: usize = 4;
/// Generic parameter name added to the impl when a parameter is enabled.
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
/// Type wrapped in `SomeTy<...>` for each enabled parameter (`Grad` carries a pair).
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
/// SPIR-V image-operand keyword for each parameter.
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
/// Field name on `params`, also used as the asm operand/id name, for each parameter.
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
/// Index of `Grad`, which needs special handling because it supplies two operands.
const SAMPLE_PARAM_GRAD_INDEX: usize = 2;
/// Mask bits (`Lod` and `Grad`) that select `ExplicitLod` instead of `ImplicitLod`.
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = (1 << 1) | (1 << SAMPLE_PARAM_GRAD_INDEX);

/// Whether index `i` of the `SAMPLE_PARAM_*` tables is the `Grad` parameter.
fn is_grad(i: usize) -> bool {
    SAMPLE_PARAM_GRAD_INDEX == i
}
644
645struct SampleImplRewriter(usize, syn::Type);
646
647impl SampleImplRewriter {
648    pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
649        let mut new_impl = f.clone();
650        let mut ty_str = String::from("SampleParams<");
651
652        // based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
653        // example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
654        for i in 0..SAMPLE_PARAM_COUNT {
655            if mask & (1 << i) != 0 {
656                new_impl.generics.params.push(syn::GenericParam::Type(
657                    syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
658                ));
659                ty_str.push_str("SomeTy<");
660                ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
661                ty_str.push('>');
662            } else {
663                ty_str.push_str("NoneTy");
664            }
665            ty_str.push(',');
666        }
667        ty_str.push('>');
668        let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
669
670        // use the type to insert it into the generic argument of the trait we're implementing
671        // e.g., `ImageWithMethods<Dummy>` becomes `ImageWithMethods<SampleParams<SomeTy<B>, NoneTy, NoneTy>>`
672        if let Some(t) = &mut new_impl.trait_ {
673            if let syn::PathArguments::AngleBracketed(a) =
674                &mut t.1.segments.last_mut().unwrap().arguments
675            {
676                if let Some(syn::GenericArgument::Type(t)) = a.args.last_mut() {
677                    *t = ty.clone();
678                }
679            }
680        }
681
682        // rewrite the implemented functions
683        SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
684        new_impl
685    }
686
687    // generates an operands string for use in the assembly, e.g. "Bias %bias Lod %lod", based on the mask
688    #[allow(clippy::needless_range_loop)]
689    fn get_operands(&self) -> String {
690        let mut op = String::new();
691        for i in 0..SAMPLE_PARAM_COUNT {
692            if self.0 & (1 << i) != 0 {
693                if is_grad(i) {
694                    op.push_str("Grad %grad_x %grad_y ");
695                } else {
696                    op.push_str(SAMPLE_PARAM_OPERANDS[i]);
697                    op.push_str(" %");
698                    op.push_str(SAMPLE_PARAM_NAMES[i]);
699                    op.push(' ');
700                }
701            }
702        }
703        op
704    }
705
706    // generates list of assembly loads for the data, e.g. "%bias = OpLoad _ {bias}", etc.
707    #[allow(clippy::needless_range_loop)]
708    fn add_loads(&self, t: &mut Vec<TokenTree>) {
709        for i in 0..SAMPLE_PARAM_COUNT {
710            if self.0 & (1 << i) != 0 {
711                if is_grad(i) {
712                    t.push(TokenTree::Literal(proc_macro2::Literal::string(
713                        "%grad_x = OpLoad _ {grad_x}",
714                    )));
715                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
716                        ',',
717                        proc_macro2::Spacing::Alone,
718                    )));
719                    t.push(TokenTree::Literal(proc_macro2::Literal::string(
720                        "%grad_y = OpLoad _ {grad_y}",
721                    )));
722                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
723                        ',',
724                        proc_macro2::Spacing::Alone,
725                    )));
726                } else {
727                    let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
728                    t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
729                    t.push(TokenTree::Punct(proc_macro2::Punct::new(
730                        ',',
731                        proc_macro2::Spacing::Alone,
732                    )));
733                }
734            }
735        }
736    }
737
738    // generates list of register specifications, e.g. `bias = in(reg) &params.bias.0, ...` as separate tokens
739    #[allow(clippy::needless_range_loop)]
740    fn add_regs(&self, t: &mut Vec<TokenTree>) {
741        for i in 0..SAMPLE_PARAM_COUNT {
742            if self.0 & (1 << i) != 0 {
743                let s = if is_grad(i) {
744                    String::from("grad_x=in(reg) &params.grad.0.0,grad_y=in(reg) &params.grad.0.1,")
745                } else {
746                    format!("{0} = in(reg) &params.{0}.0,", SAMPLE_PARAM_NAMES[i])
747                };
748                let ts: proc_macro2::TokenStream = s.parse().unwrap();
749                t.extend(ts);
750            }
751        }
752    }
753}
754
755impl VisitMut for SampleImplRewriter {
756    fn visit_impl_item_method_mut(&mut self, item: &mut syn::ImplItemMethod) {
757        // rewrite the last parameter of this method to be of type `SampleParams<...>` we generated earlier
758        if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
759            *p.ty.as_mut() = self.1.clone();
760        }
761        syn::visit_mut::visit_impl_item_method_mut(self, item);
762    }
763
764    fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
765        if m.path.is_ident("asm") {
766            // this is where the asm! block is manipulated
767            let t = m.tokens.clone();
768            let mut new_t = Vec::new();
769            let mut altered = false;
770
771            for tt in t {
772                match tt {
773                    TokenTree::Literal(l) => {
774                        if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
775                            // found a string literal
776                            let s = l.value();
777                            if s.contains("$PARAMS") {
778                                altered = true;
779                                // add load instructions before the sampling instruction
780                                self.add_loads(&mut new_t);
781                                // and insert image operands
782                                let s = s.replace("$PARAMS", &self.get_operands());
783                                let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
784                                    "ExplicitLod"
785                                } else {
786                                    "ImplicitLod "
787                                };
788                                let s = s.replace("$LOD", lod_type);
789
790                                new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
791                                    s.as_str(),
792                                )));
793                            } else {
794                                new_t.push(TokenTree::Literal(l.token()));
795                            }
796                        } else {
797                            new_t.push(TokenTree::Literal(l));
798                        }
799                    }
800                    _ => {
801                        new_t.push(tt);
802                    }
803                }
804            }
805
806            if altered {
807                // finally, add register specs
808                self.add_regs(&mut new_t);
809            }
810
811            // replace all tokens within the asm! block with our new list
812            m.tokens = new_t.into_iter().collect();
813        }
814    }
815}
816
817/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
818/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
819/// of each function must be named `params`, its type will be rewritten. Relevant generic
820/// arguments are added to the impl generics.
821/// See `SAMPLE_PARAM_GENERICS` for a list of names you cannot use as generic arguments.
822#[proc_macro_attribute]
823#[doc(hidden)]
824pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
825    let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
826    let mut fns = Vec::new();
827
828    for m in 1..(1 << SAMPLE_PARAM_COUNT) {
829        fns.push(SampleImplRewriter::rewrite(m, &item_impl));
830    }
831
832    // uncomment to output generated tokenstream to stdout
833    //println!("{}", quote! { #(#fns)* }.to_string());
834    quote! { #(#fns)* }.into()
835}