1#![deny(unsafe_code)]
5#![warn(
6 clippy::all,
7 clippy::await_holding_lock,
8 clippy::char_lit_as_u8,
9 clippy::checked_conversions,
10 clippy::dbg_macro,
11 clippy::debug_assert_with_mut_call,
12 clippy::doc_markdown,
13 clippy::empty_enum,
14 clippy::enum_glob_use,
15 clippy::exit,
16 clippy::expl_impl_clone_on_copy,
17 clippy::explicit_deref_methods,
18 clippy::explicit_into_iter_loop,
19 clippy::fallible_impl_from,
20 clippy::filter_map_next,
21 clippy::float_cmp_const,
22 clippy::fn_params_excessive_bools,
23 clippy::if_let_mutex,
24 clippy::implicit_clone,
25 clippy::imprecise_flops,
26 clippy::inefficient_to_string,
27 clippy::invalid_upcast_comparisons,
28 clippy::large_types_passed_by_value,
29 clippy::let_unit_value,
30 clippy::linkedlist,
31 clippy::lossy_float_literal,
32 clippy::macro_use_imports,
33 clippy::manual_ok_or,
34 clippy::map_err_ignore,
35 clippy::map_flatten,
36 clippy::map_unwrap_or,
37 clippy::match_on_vec_items,
38 clippy::match_same_arms,
39 clippy::match_wildcard_for_single_variants,
40 clippy::mem_forget,
41 clippy::mismatched_target_os,
42 clippy::mut_mut,
43 clippy::mutex_integer,
44 clippy::needless_borrow,
45 clippy::needless_continue,
46 clippy::option_option,
47 clippy::path_buf_push_overwrite,
48 clippy::ptr_as_ptr,
49 clippy::ref_option_ref,
50 clippy::rest_pat_in_fully_bound_structs,
51 clippy::same_functions_in_if_condition,
52 clippy::semicolon_if_nothing_returned,
53 clippy::string_add_assign,
54 clippy::string_add,
55 clippy::string_lit_as_bytes,
56 clippy::string_to_string,
57 clippy::todo,
58 clippy::trait_duplication_in_bounds,
59 clippy::unimplemented,
60 clippy::unnested_or_patterns,
61 clippy::unused_self,
62 clippy::useless_transmute,
63 clippy::verbose_file_reads,
64 clippy::zero_sized_map_values,
65 future_incompatible,
66 nonstandard_style,
67 rust_2018_idioms
68)]
69#![doc = include_str!("../README.md")]
73
74mod image;
75
76use proc_macro::TokenStream;
77use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
78
79use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, ItemFn, Token};
80
81use quote::{quote, ToTokens};
82use std::fmt::Write;
83
84#[proc_macro]
133#[allow(nonstandard_style)]
136pub fn Image(item: TokenStream) -> TokenStream {
137 let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
138
139 output.into()
140}
141
142#[proc_macro_attribute]
145pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
146 let mut tokens: Vec<TokenTree> = Vec::new();
147
148 let attr: proc_macro2::TokenStream = attr.into();
150 tokens.extend(quote! { #[cfg_attr(target_arch="spirv", rust_gpu::spirv(#attr))] });
151
152 let item: proc_macro2::TokenStream = item.into();
153 for tt in item {
154 match tt {
155 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
156 let mut sub_tokens = Vec::new();
157 for tt in group.stream() {
158 match tt {
159 TokenTree::Group(group)
160 if group.delimiter() == Delimiter::Bracket
161 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv")
162 && matches!(sub_tokens.last(), Some(TokenTree::Punct(p)) if p.as_char() == '#') =>
163 {
164 let inner = group.stream(); sub_tokens.extend(
167 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#inner)] },
168 );
169 }
170 _ => sub_tokens.push(tt),
171 }
172 }
173 tokens.push(TokenTree::from(Group::new(
174 Delimiter::Parenthesis,
175 sub_tokens.into_iter().collect(),
176 )));
177 }
178 _ => tokens.push(tt),
179 }
180 }
181 tokens
182 .into_iter()
183 .collect::<proc_macro2::TokenStream>()
184 .into()
185}
186
187#[proc_macro_attribute]
190pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
191 let syn::ItemFn {
192 attrs,
193 vis,
194 sig,
195 block,
196 } = syn::parse_macro_input!(item as syn::ItemFn);
197
198 #[allow(clippy::redundant_clone)]
200 let fn_name = sig.ident.clone();
201
202 let sig_cpu = syn::Signature {
203 abi: None,
204 ..sig.clone()
205 };
206
207 let output = quote::quote! {
208 #[cfg(not(target_arch="spirv"))]
210 #[allow(unused_variables)]
211 #(#attrs)* #vis #sig_cpu {
212 unimplemented!(concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms."))
213 }
214
215 #[cfg(target_arch="spirv")]
216 #(#attrs)* #vis #sig {
217 #block
218 }
219 };
220
221 output.into()
222}
223
224#[proc_macro_attribute]
229#[doc(hidden)]
230pub fn vectorized(_attr: TokenStream, item: TokenStream) -> TokenStream {
231 let function = syn::parse_macro_input!(item as syn::ItemFn);
232 let vectored_function = match create_vectored_fn(function.clone()) {
233 Ok(val) => val,
234 Err(err) => return err.to_compile_error().into(),
235 };
236
237 let output = quote::quote!(
238 #function
239
240 #vectored_function
241 );
242
243 output.into()
244}
245
246fn create_vectored_fn(
247 ItemFn {
248 attrs,
249 vis,
250 mut sig,
251 block,
252 }: ItemFn,
253) -> Result<ItemFn, syn::Error> {
254 const COMPONENT_ARG_NAME: &str = "component";
255 let trait_bound_name = Ident::new("VECTOR", Span::mixed_site());
256 let const_bound_name = Ident::new("LENGTH", Span::mixed_site());
257
258 sig.ident = Ident::new(&format!("{}_vector", sig.ident), Span::mixed_site());
259 sig.output = syn::ReturnType::Type(
260 Default::default(),
261 Box::new(path_from_ident(trait_bound_name.clone())),
262 );
263
264 let component_type = sig.inputs.iter_mut().find_map(|x| match x {
265 syn::FnArg::Typed(ty) => match &*ty.pat {
266 syn::Pat::Ident(pat) if pat.ident == COMPONENT_ARG_NAME => Some(&mut ty.ty),
267 _ => None,
268 },
269 syn::FnArg::Receiver(_) => None,
270 });
271
272 if component_type.is_none() {
273 return Err(syn::Error::new(
274 sig.inputs.span(),
275 "#[vectorized] requires an argument named `component`.",
276 ));
277 }
278 let component_type = component_type.unwrap();
279
280 let vector_path = {
281 let mut path = syn::Path {
282 leading_colon: None,
283 segments: Punctuated::new(),
284 };
285
286 for segment in &["crate", "vector"] {
287 path.segments
288 .push(Ident::new(segment, Span::mixed_site()).into());
289 }
290
291 path.segments.push(syn::PathSegment {
292 ident: Ident::new("Vector", Span::mixed_site()),
293 arguments: syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
294 colon2_token: None,
295 lt_token: Default::default(),
296 args: {
297 let mut punct = Punctuated::new();
298
299 punct.push(syn::GenericArgument::Type(*component_type.clone()));
300 punct.push(syn::GenericArgument::Type(path_from_ident(
301 const_bound_name.clone(),
302 )));
303
304 punct
305 },
306 gt_token: Default::default(),
307 }),
308 });
309
310 path
311 };
312
313 **component_type = path_from_ident(trait_bound_name.clone());
315
316 let trait_bounds = {
317 let mut punct = Punctuated::new();
318 punct.push(syn::TypeParamBound::Trait(syn::TraitBound {
319 paren_token: None,
320 modifier: syn::TraitBoundModifier::None,
321 lifetimes: None,
322 path: vector_path,
323 }));
324 punct
325 };
326
327 sig.generics
328 .params
329 .push(syn::GenericParam::Type(syn::TypeParam {
330 attrs: Vec::new(),
331 ident: trait_bound_name,
332 colon_token: Some(Token)),
333 bounds: trait_bounds,
334 eq_token: None,
335 default: None,
336 }));
337
338 sig.generics
339 .params
340 .push(syn::GenericParam::Const(syn::ConstParam {
341 attrs: Vec::default(),
342 const_token: Default::default(),
343 ident: const_bound_name,
344 colon_token: Default::default(),
345 ty: syn::Type::Path(syn::TypePath {
346 qself: None,
347 path: Ident::new("usize", Span::mixed_site()).into(),
348 }),
349 eq_token: None,
350 default: None,
351 }));
352
353 Ok(ItemFn {
354 attrs,
355 vis,
356 sig,
357 block,
358 })
359}
360
361fn path_from_ident(ident: Ident) -> syn::Type {
362 syn::Type::Path(syn::TypePath {
363 qself: None,
364 path: syn::Path::from(ident),
365 })
366}
367
368#[proc_macro]
379pub fn debug_printf(input: TokenStream) -> TokenStream {
380 debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
381}
382
383#[proc_macro]
385pub fn debug_printfln(input: TokenStream) -> TokenStream {
386 let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
387 input.format_string.push('\n');
388 debug_printf_inner(input)
389}
390
391struct DebugPrintfInput {
392 span: proc_macro2::Span,
393 format_string: String,
394 variables: Vec<syn::Expr>,
395}
396
397impl syn::parse::Parse for DebugPrintfInput {
398 fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result<Self> {
399 let span = input.span();
400
401 if input.is_empty() {
402 return Ok(Self {
403 span,
404 format_string: Default::default(),
405 variables: Default::default(),
406 });
407 }
408
409 let format_string = input.parse::<syn::LitStr>()?;
410 if !input.is_empty() {
411 input.parse::<syn::token::Comma>()?;
412 }
413 let variables =
414 syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated(input)?;
415
416 Ok(Self {
417 span,
418 format_string: format_string.value(),
419 variables: variables.into_iter().collect(),
420 })
421 }
422}
423
424fn parsing_error(message: &str, span: proc_macro2::Span) -> TokenStream {
425 syn::Error::new(span, message).to_compile_error().into()
426}
427
428enum FormatType {
429 Scalar {
430 ty: proc_macro2::TokenStream,
431 },
432 Vector {
433 ty: proc_macro2::TokenStream,
434 width: usize,
435 },
436}
437
438fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
439 let DebugPrintfInput {
440 format_string,
441 variables,
442 span,
443 } = input;
444
445 fn map_specifier_to_type(
446 specifier: char,
447 chars: &mut std::str::Chars<'_>,
448 ) -> Option<proc_macro2::TokenStream> {
449 let mut peekable = chars.peekable();
450
451 Some(match specifier {
452 'd' | 'i' => quote::quote! { i32 },
453 'o' | 'x' | 'X' => quote::quote! { u32 },
454 'a' | 'A' | 'e' | 'E' | 'f' | 'F' | 'g' | 'G' => quote::quote! { f32 },
455 'u' => {
456 if matches!(peekable.peek(), Some('l')) {
457 chars.next();
458 quote::quote! { u64 }
459 } else {
460 quote::quote! { u32 }
461 }
462 }
463 'l' => {
464 if matches!(peekable.peek(), Some('u' | 'x')) {
465 chars.next();
466 quote::quote! { u64 }
467 } else {
468 return None;
469 }
470 }
471 _ => return None,
472 })
473 }
474
475 let mut chars = format_string.chars();
476 let mut format_arguments = Vec::new();
477
478 while let Some(mut ch) = chars.next() {
479 if ch == '%' {
480 ch = match chars.next() {
481 Some('%') => continue,
482 None => return parsing_error("Unterminated format specifier", span),
483 Some(ch) => ch,
484 };
485
486 let mut has_precision = false;
487
488 while ch.is_ascii_digit() {
489 ch = match chars.next() {
490 Some(ch) => ch,
491 None => {
492 return parsing_error(
493 "Unterminated format specifier: missing type after precision",
494 span,
495 );
496 }
497 };
498
499 has_precision = true;
500 }
501
502 if has_precision && ch == '.' {
503 ch = match chars.next() {
504 Some(ch) => ch,
505 None => {
506 return parsing_error(
507 "Unterminated format specifier: missing type after decimal point",
508 span,
509 );
510 }
511 };
512
513 while ch.is_ascii_digit() {
514 ch = match chars.next() {
515 Some(ch) => ch,
516 None => {
517 return parsing_error(
518 "Unterminated format specifier: missing type after fraction precision",
519 span,
520 );
521 }
522 };
523 }
524 }
525
526 if ch == 'v' {
527 let width = match chars.next() {
528 Some('2') => 2,
529 Some('3') => 3,
530 Some('4') => 4,
531 Some(ch) => {
532 return parsing_error(&format!("Invalid width for vector: {ch}"), span);
533 }
534 None => return parsing_error("Missing vector dimensions specifier", span),
535 };
536
537 ch = match chars.next() {
538 Some(ch) => ch,
539 None => return parsing_error("Missing vector type specifier", span),
540 };
541
542 let ty = match map_specifier_to_type(ch, &mut chars) {
543 Some(ty) => ty,
544 _ => {
545 return parsing_error(
546 &format!("Unrecognised vector type specifier: '{ch}'"),
547 span,
548 );
549 }
550 };
551
552 format_arguments.push(FormatType::Vector { ty, width });
553 } else {
554 let ty = match map_specifier_to_type(ch, &mut chars) {
555 Some(ty) => ty,
556 _ => {
557 return parsing_error(
558 &format!("Unrecognised format specifier: '{ch}'"),
559 span,
560 );
561 }
562 };
563
564 format_arguments.push(FormatType::Scalar { ty });
565 }
566 }
567 }
568
569 if format_arguments.len() != variables.len() {
570 return syn::Error::new(
571 span,
572 format!(
573 "{} % arguments were found, but {} variables were given",
574 format_arguments.len(),
575 variables.len()
576 ),
577 )
578 .to_compile_error()
579 .into();
580 }
581
582 let mut variable_idents = String::new();
583 let mut input_registers = Vec::new();
584 let mut op_loads = Vec::new();
585
586 for (i, (variable, format_argument)) in variables.into_iter().zip(format_arguments).enumerate()
587 {
588 let ident = quote::format_ident!("_{}", i);
589
590 let _ = write!(variable_idents, "%{ident} ");
591
592 let assert_fn = match format_argument {
593 FormatType::Scalar { ty } => {
594 quote::quote! { spirv_std::debug_printf_assert_is_type::<#ty> }
595 }
596 FormatType::Vector { ty, width } => {
597 quote::quote! { spirv_std::debug_printf_assert_is_vector::<#ty, _, #width> }
598 }
599 };
600
601 input_registers.push(quote::quote! {
602 #ident = in(reg) &#assert_fn(#variable),
603 });
604
605 let op_load = format!("%{ident} = OpLoad _ {{{ident}}}");
606
607 op_loads.push(quote::quote! {
608 #op_load,
609 });
610 }
611
612 let input_registers = input_registers
613 .into_iter()
614 .collect::<proc_macro2::TokenStream>();
615 let op_loads = op_loads.into_iter().collect::<proc_macro2::TokenStream>();
616
617 let op_string = format!("%string = OpString {format_string:?}");
618
619 let output = quote::quote! {
620 ::core::arch::asm!(
621 "%void = OpTypeVoid",
622 #op_string,
623 "%debug_printf = OpExtInstImport \"NonSemantic.DebugPrintf\"",
624 #op_loads
625 concat!("%result = OpExtInst %void %debug_printf 1 %string ", #variable_idents),
626 #input_registers
627 )
628 };
629
630 output.into()
631}
632
/// Number of optional image-sampling parameters; bit `i` of a mask enables parameter `i`.
const SAMPLE_PARAM_COUNT: usize = 4;

/// Generic type-parameter name introduced for each enabled parameter.
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];

/// Type placed inside `SomeTy<...>` for each enabled parameter (gradients are a pair).
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];

/// SPIR-V image-operand keyword written for each enabled parameter.
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];

/// Field / asm-operand name used for each parameter.
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];

/// Index of the gradient parameter, which expands to two operands (`grad_x`, `grad_y`).
const SAMPLE_PARAM_GRAD_INDEX: usize = 2;

/// Bits (lod, grad) that select the `ExplicitLod` sampling variant when any is enabled.
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = 0b0110;

/// Returns `true` when parameter index `i` is the gradient parameter.
fn is_grad(i: usize) -> bool {
    matches!(i, SAMPLE_PARAM_GRAD_INDEX)
}
644
645struct SampleImplRewriter(usize, syn::Type);
646
647impl SampleImplRewriter {
648 pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
649 let mut new_impl = f.clone();
650 let mut ty_str = String::from("SampleParams<");
651
652 for i in 0..SAMPLE_PARAM_COUNT {
655 if mask & (1 << i) != 0 {
656 new_impl.generics.params.push(syn::GenericParam::Type(
657 syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
658 ));
659 ty_str.push_str("SomeTy<");
660 ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
661 ty_str.push('>');
662 } else {
663 ty_str.push_str("NoneTy");
664 }
665 ty_str.push(',');
666 }
667 ty_str.push('>');
668 let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
669
670 if let Some(t) = &mut new_impl.trait_ {
673 if let syn::PathArguments::AngleBracketed(a) =
674 &mut t.1.segments.last_mut().unwrap().arguments
675 {
676 if let Some(syn::GenericArgument::Type(t)) = a.args.last_mut() {
677 *t = ty.clone();
678 }
679 }
680 }
681
682 SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
684 new_impl
685 }
686
687 #[allow(clippy::needless_range_loop)]
689 fn get_operands(&self) -> String {
690 let mut op = String::new();
691 for i in 0..SAMPLE_PARAM_COUNT {
692 if self.0 & (1 << i) != 0 {
693 if is_grad(i) {
694 op.push_str("Grad %grad_x %grad_y ");
695 } else {
696 op.push_str(SAMPLE_PARAM_OPERANDS[i]);
697 op.push_str(" %");
698 op.push_str(SAMPLE_PARAM_NAMES[i]);
699 op.push(' ');
700 }
701 }
702 }
703 op
704 }
705
706 #[allow(clippy::needless_range_loop)]
708 fn add_loads(&self, t: &mut Vec<TokenTree>) {
709 for i in 0..SAMPLE_PARAM_COUNT {
710 if self.0 & (1 << i) != 0 {
711 if is_grad(i) {
712 t.push(TokenTree::Literal(proc_macro2::Literal::string(
713 "%grad_x = OpLoad _ {grad_x}",
714 )));
715 t.push(TokenTree::Punct(proc_macro2::Punct::new(
716 ',',
717 proc_macro2::Spacing::Alone,
718 )));
719 t.push(TokenTree::Literal(proc_macro2::Literal::string(
720 "%grad_y = OpLoad _ {grad_y}",
721 )));
722 t.push(TokenTree::Punct(proc_macro2::Punct::new(
723 ',',
724 proc_macro2::Spacing::Alone,
725 )));
726 } else {
727 let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
728 t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
729 t.push(TokenTree::Punct(proc_macro2::Punct::new(
730 ',',
731 proc_macro2::Spacing::Alone,
732 )));
733 }
734 }
735 }
736 }
737
738 #[allow(clippy::needless_range_loop)]
740 fn add_regs(&self, t: &mut Vec<TokenTree>) {
741 for i in 0..SAMPLE_PARAM_COUNT {
742 if self.0 & (1 << i) != 0 {
743 let s = if is_grad(i) {
744 String::from("grad_x=in(reg) ¶ms.grad.0.0,grad_y=in(reg) ¶ms.grad.0.1,")
745 } else {
746 format!("{0} = in(reg) ¶ms.{0}.0,", SAMPLE_PARAM_NAMES[i])
747 };
748 let ts: proc_macro2::TokenStream = s.parse().unwrap();
749 t.extend(ts);
750 }
751 }
752 }
753}
754
755impl VisitMut for SampleImplRewriter {
756 fn visit_impl_item_method_mut(&mut self, item: &mut syn::ImplItemMethod) {
757 if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
759 *p.ty.as_mut() = self.1.clone();
760 }
761 syn::visit_mut::visit_impl_item_method_mut(self, item);
762 }
763
764 fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
765 if m.path.is_ident("asm") {
766 let t = m.tokens.clone();
768 let mut new_t = Vec::new();
769 let mut altered = false;
770
771 for tt in t {
772 match tt {
773 TokenTree::Literal(l) => {
774 if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
775 let s = l.value();
777 if s.contains("$PARAMS") {
778 altered = true;
779 self.add_loads(&mut new_t);
781 let s = s.replace("$PARAMS", &self.get_operands());
783 let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
784 "ExplicitLod"
785 } else {
786 "ImplicitLod "
787 };
788 let s = s.replace("$LOD", lod_type);
789
790 new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
791 s.as_str(),
792 )));
793 } else {
794 new_t.push(TokenTree::Literal(l.token()));
795 }
796 } else {
797 new_t.push(TokenTree::Literal(l));
798 }
799 }
800 _ => {
801 new_t.push(tt);
802 }
803 }
804 }
805
806 if altered {
807 self.add_regs(&mut new_t);
809 }
810
811 m.tokens = new_t.into_iter().collect();
813 }
814 }
815}
816
817#[proc_macro_attribute]
823#[doc(hidden)]
824pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
825 let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
826 let mut fns = Vec::new();
827
828 for m in 1..(1 << SAMPLE_PARAM_COUNT) {
829 fns.push(SampleImplRewriter::rewrite(m, &item_impl));
830 }
831
832 quote! { #(#fns)* }.into()
835}