warp_types_kernel/lib.rs
1//! Proc macro for marking GPU kernel functions.
2//!
//! `#[warp_kernel]` transforms a function into a PTX kernel entry point
//! (`#[no_mangle] unsafe extern "ptx-kernel" fn ...`). The expansion is the
//! same on every target; no host-side launcher is generated by this macro.
6//!
7//! # Usage
8//!
9//! In your kernel crate (compiled for nvptx64):
10//!
11//! ```rust,ignore
12//! use warp_types::prelude::*;
13//! use warp_types_kernel::warp_kernel;
14//!
15//! #[warp_kernel]
16//! pub fn butterfly_reduce(data: *mut i32) {
17//! let warp: Warp<All> = Warp::kernel_entry();
18//! let tid = warp_types::gpu::thread_id_x();
19//! let mut val = unsafe { *data.add(tid as usize) };
20//!
21//! val += warp.shuffle_xor(PerLane::new(val), 16).get();
22//! val += warp.shuffle_xor(PerLane::new(val), 8).get();
23//! val += warp.shuffle_xor(PerLane::new(val), 4).get();
24//! val += warp.shuffle_xor(PerLane::new(val), 2).get();
25//! val += warp.shuffle_xor(PerLane::new(val), 1).get();
26//!
27//! unsafe { *data.add(tid as usize) = val; }
28//! }
29//! ```
30//!
//! The macro always emits `#[no_mangle] <vis> unsafe extern "ptx-kernel" fn ...`
//! (keeping the visibility written on the function) regardless of target.
//! Kernel crates should target nvptx64 exclusively — the `extern "ptx-kernel"`
//! ABI requires nightly `abi_ptx` and is only meaningful on GPU targets.
35
36use proc_macro::TokenStream;
37use quote::quote;
38use syn::{parse_macro_input, FnArg, ItemFn, Pat};
39
40/// Mark a function as a GPU kernel entry point.
41///
42/// This attribute transforms the function signature for PTX compilation:
43/// - Adds `#[no_mangle]` for symbol visibility in PTX
44/// - Adds `extern "ptx-kernel"` ABI
45/// - Wraps the body in `unsafe` (PTX kernels are inherently unsafe)
46///
47/// # Parameter Rules
48///
49/// Kernel parameters must be one of:
50/// - Raw pointers (`*const T`, `*mut T`) — for device memory
51/// - Scalars (`u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`, `f32`, `f64`, `bool`) — passed by value
52///
53/// Note: `usize`/`isize` are rejected because their width is platform-dependent.
54/// On nvptx64 they are 64-bit, but the host launcher may assume a different size,
55/// causing ABI mismatch. Use explicit-width types (`u32`, `u64`, etc.) instead.
56///
57/// # Compile-Time Safety
58///
59/// The function body uses warp-types normally. `Warp::kernel_entry()` creates
60/// the initial `Warp<All>`, and the type system prevents shuffle-from-inactive-lane
61/// bugs at compile time — on the actual GPU target.
62#[proc_macro_attribute]
63pub fn warp_kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
64 let input = parse_macro_input!(item as ItemFn);
65
66 let name = &input.sig.ident;
67 let params = &input.sig.inputs;
68 let body = &input.block;
69 let vis = &input.vis;
70 let attrs = &input.attrs;
71
72 // PTX kernels must be void — reject non-unit return types
73 if let syn::ReturnType::Type(_, ref ty) = input.sig.output {
74 let msg = "warp_kernel: GPU kernels must return `()`. \
75 PTX kernel entry points are always void.";
76 return syn::Error::new_spanned(ty, msg).to_compile_error().into();
77 }
78
79 // PTX kernels cannot be generic
80 if !input.sig.generics.params.is_empty() {
81 let msg = "warp_kernel: GPU kernels cannot be generic. \
82 PTX entry points require concrete types.";
83 return syn::Error::new_spanned(&input.sig.generics, msg)
84 .to_compile_error()
85 .into();
86 }
87
88 // Validate parameters: must be raw pointers or scalars
89 for param in params.iter() {
90 if let FnArg::Typed(pat_type) = param {
91 if let Err(err) = validate_kernel_param(&pat_type.ty, &pat_type.pat) {
92 return err;
93 }
94 }
95 }
96
97 // Generate the kernel function for nvptx64
98 // Preserve outer attributes (doc comments, #[cfg], etc.)
99 let expanded = quote! {
100 #(#attrs)*
101 #[no_mangle]
102 #vis unsafe extern "ptx-kernel" fn #name(#params) #body
103 };
104
105 TokenStream::from(expanded)
106}
107
108/// Validate that a kernel parameter type is GPU-compatible.
109///
110/// Returns `Ok(())` if valid, `Err(TokenStream)` with a `compile_error!` if not.
111fn validate_kernel_param(ty: &syn::Type, pat: &Pat) -> Result<(), TokenStream> {
112 match ty {
113 // Raw pointers are always OK
114 syn::Type::Ptr(_) => Ok(()),
115 // Path types: check if they're known scalars
116 syn::Type::Path(tp) => {
117 // Reject qualified paths (e.g., my_crate::u32) — kernel params must be plain scalars
118 if tp.path.segments.len() > 1 {
119 let msg = format!(
120 "warp_kernel: parameter `{}` uses qualified type `{}`. \
121 Use unqualified scalar types (u32, i32, f32, etc.) for kernel parameters.",
122 quote!(#pat),
123 quote!(#ty)
124 );
125 return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
126 }
127 if let Some(seg) = tp.path.segments.last() {
128 let name = seg.ident.to_string();
129 let valid_scalars = [
130 "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64", "bool",
131 ];
132 if !valid_scalars.contains(&name.as_str()) {
133 let msg = format!(
134 "warp_kernel: parameter `{}` has type `{}` which is not a GPU-compatible type. \
135 Use raw pointers (*const T, *mut T) for device memory or scalar types (u32, i32, f32, etc.).",
136 quote!(#pat), name
137 );
138 return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
139 }
140 }
141 Ok(())
142 }
143 _ => {
144 let msg = format!(
145 "warp_kernel: parameter `{}` has unsupported type `{}`. \
146 Kernel parameters must be raw pointers or scalar types.",
147 quote!(#pat),
148 quote!(#ty)
149 );
150 Err(syn::Error::new_spanned(ty, msg).to_compile_error().into())
151 }
152 }
153}