Skip to main content

warp_types_kernel/
lib.rs

1//! Proc macro for marking GPU kernel functions.
2//!
3//! `#[warp_kernel]` transforms a function into a PTX kernel entry point by
4//! rewriting its signature to `#[no_mangle] unsafe extern "ptx-kernel" fn`.
5//! It does not generate a host-side launcher; the expansion is the same for every target.
6//!
7//! # Usage
8//!
9//! In your kernel crate (compiled for nvptx64):
10//!
11//! ```rust,ignore
12//! use warp_types::prelude::*;
13//! use warp_types_kernel::warp_kernel;
14//!
15//! #[warp_kernel]
16//! pub fn butterfly_reduce(data: *mut i32) {
17//!     let warp: Warp<All> = Warp::kernel_entry();
18//!     let tid = warp_types::gpu::thread_id_x();
19//!     let mut val = unsafe { *data.add(tid as usize) };
20//!
21//!     val += warp.shuffle_xor(PerLane::new(val), 16).get();
22//!     val += warp.shuffle_xor(PerLane::new(val), 8).get();
23//!     val += warp.shuffle_xor(PerLane::new(val), 4).get();
24//!     val += warp.shuffle_xor(PerLane::new(val), 2).get();
25//!     val += warp.shuffle_xor(PerLane::new(val), 1).get();
26//!
27//!     unsafe { *data.add(tid as usize) = val; }
28//! }
29//! ```
30//!
31//! The macro always emits `#[no_mangle] unsafe extern "ptx-kernel" fn ...`
32//! (keeping the function's own visibility) regardless of target. Kernel crates
33//! should target nvptx64 exclusively — the `extern "ptx-kernel"` ABI requires
34//! nightly `abi_ptx` and is only meaningful on GPU targets.
35
36use proc_macro::TokenStream;
37use quote::quote;
38use syn::{parse_macro_input, FnArg, ItemFn, Pat};
39
40/// Mark a function as a GPU kernel entry point.
41///
42/// This attribute transforms the function signature for PTX compilation:
43/// - Adds `#[no_mangle]` for symbol visibility in PTX
44/// - Adds `extern "ptx-kernel"` ABI
45/// - Wraps the body in `unsafe` (PTX kernels are inherently unsafe)
46///
47/// # Parameter Rules
48///
49/// Kernel parameters must be one of:
50/// - Raw pointers (`*const T`, `*mut T`) — for device memory
51/// - Scalars (`u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`, `f32`, `f64`, `bool`) — passed by value
52///
53/// Note: `usize`/`isize` are rejected because their width is platform-dependent.
54/// On nvptx64 they are 64-bit, but the host launcher may assume a different size,
55/// causing ABI mismatch. Use explicit-width types (`u32`, `u64`, etc.) instead.
56///
57/// # Compile-Time Safety
58///
59/// The function body uses warp-types normally. `Warp::kernel_entry()` creates
60/// the initial `Warp<All>`, and the type system prevents shuffle-from-inactive-lane
61/// bugs at compile time — on the actual GPU target.
62#[proc_macro_attribute]
63pub fn warp_kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
64    let input = parse_macro_input!(item as ItemFn);
65
66    let name = &input.sig.ident;
67    let params = &input.sig.inputs;
68    let body = &input.block;
69    let vis = &input.vis;
70    let attrs = &input.attrs;
71
72    // PTX kernels must be void — reject non-unit return types
73    if let syn::ReturnType::Type(_, ref ty) = input.sig.output {
74        let msg = "warp_kernel: GPU kernels must return `()`. \
75                   PTX kernel entry points are always void.";
76        return syn::Error::new_spanned(ty, msg).to_compile_error().into();
77    }
78
79    // PTX kernels cannot be generic
80    if !input.sig.generics.params.is_empty() {
81        let msg = "warp_kernel: GPU kernels cannot be generic. \
82                   PTX entry points require concrete types.";
83        return syn::Error::new_spanned(&input.sig.generics, msg)
84            .to_compile_error()
85            .into();
86    }
87
88    // Validate parameters: must be raw pointers or scalars
89    for param in params.iter() {
90        if let FnArg::Typed(pat_type) = param {
91            if let Err(err) = validate_kernel_param(&pat_type.ty, &pat_type.pat) {
92                return err;
93            }
94        }
95    }
96
97    // Generate the kernel function for nvptx64
98    // Preserve outer attributes (doc comments, #[cfg], etc.)
99    let expanded = quote! {
100        #(#attrs)*
101        #[no_mangle]
102        #vis unsafe extern "ptx-kernel" fn #name(#params) #body
103    };
104
105    TokenStream::from(expanded)
106}
107
108/// Validate that a kernel parameter type is GPU-compatible.
109///
110/// Returns `Ok(())` if valid, `Err(TokenStream)` with a `compile_error!` if not.
111fn validate_kernel_param(ty: &syn::Type, pat: &Pat) -> Result<(), TokenStream> {
112    match ty {
113        // Raw pointers are always OK
114        syn::Type::Ptr(_) => Ok(()),
115        // Path types: check if they're known scalars
116        syn::Type::Path(tp) => {
117            // Reject qualified paths (e.g., my_crate::u32) — kernel params must be plain scalars
118            if tp.path.segments.len() > 1 {
119                let msg = format!(
120                    "warp_kernel: parameter `{}` uses qualified type `{}`. \
121                     Use unqualified scalar types (u32, i32, f32, etc.) for kernel parameters.",
122                    quote!(#pat),
123                    quote!(#ty)
124                );
125                return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
126            }
127            if let Some(seg) = tp.path.segments.last() {
128                let name = seg.ident.to_string();
129                let valid_scalars = [
130                    "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64", "bool",
131                ];
132                if !valid_scalars.contains(&name.as_str()) {
133                    let msg = format!(
134                        "warp_kernel: parameter `{}` has type `{}` which is not a GPU-compatible type. \
135                         Use raw pointers (*const T, *mut T) for device memory or scalar types (u32, i32, f32, etc.).",
136                        quote!(#pat), name
137                    );
138                    return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
139                }
140            }
141            Ok(())
142        }
143        _ => {
144            let msg = format!(
145                "warp_kernel: parameter `{}` has unsupported type `{}`. \
146                 Kernel parameters must be raw pointers or scalar types.",
147                quote!(#pat),
148                quote!(#ty)
149            );
150            Err(syn::Error::new_spanned(ty, msg).to_compile_error().into())
151        }
152    }
153}