Skip to main content

warp_types_kernel/
lib.rs

//! Proc macro for marking GPU kernel functions.
//!
//! `#[warp_kernel]` transforms a function into a proper PTX kernel entry point
//! when compiling for nvptx64. It does not generate a host-side launcher:
//! kernel functions are only meant to be compiled for the GPU target.
//!
//! # Usage
//!
//! In your kernel crate (compiled for nvptx64):
//!
//! ```rust,ignore
//! use warp_types::prelude::*;
//! use warp_types_kernel::warp_kernel;
//!
//! #[warp_kernel]
//! pub fn butterfly_reduce(data: *mut i32) {
//!     let warp: Warp<All> = Warp::kernel_entry();
//!     let tid = warp_types::gpu::thread_id_x();
//!     let mut val = unsafe { *data.add(tid as usize) };
//!
//!     val += warp.shuffle_xor(PerLane::new(val), 16).get();
//!     val += warp.shuffle_xor(PerLane::new(val), 8).get();
//!     val += warp.shuffle_xor(PerLane::new(val), 4).get();
//!     val += warp.shuffle_xor(PerLane::new(val), 2).get();
//!     val += warp.shuffle_xor(PerLane::new(val), 1).get();
//!
//!     unsafe { *data.add(tid as usize) = val; }
//! }
//! ```
//!
//! The macro emits:
//! - On nvptx64: `#[no_mangle] pub unsafe extern "ptx-kernel" fn butterfly_reduce(...)`
//! - On host: nothing (kernel functions are only compiled for GPU)
34
35use proc_macro::TokenStream;
36use quote::quote;
37use syn::{parse_macro_input, FnArg, ItemFn, Pat};
38
39/// Mark a function as a GPU kernel entry point.
40///
41/// This attribute transforms the function signature for PTX compilation:
42/// - Adds `#[no_mangle]` for symbol visibility in PTX
43/// - Adds `extern "ptx-kernel"` ABI
44/// - Wraps the body in `unsafe` (PTX kernels are inherently unsafe)
45///
46/// # Parameter Rules
47///
48/// Kernel parameters must be one of:
49/// - Raw pointers (`*const T`, `*mut T`) — for device memory
50/// - Scalars (`u32`, `i32`, `f32`, `u64`, `i64`, `f64`, `bool`) — passed by value
51///
52/// # Compile-Time Safety
53///
54/// The function body uses warp-types normally. `Warp::kernel_entry()` creates
55/// the initial `Warp<All>`, and the type system prevents shuffle-from-inactive-lane
56/// bugs at compile time — on the actual GPU target.
57#[proc_macro_attribute]
58pub fn warp_kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
59    let input = parse_macro_input!(item as ItemFn);
60
61    let name = &input.sig.ident;
62    let params = &input.sig.inputs;
63    let body = &input.block;
64    let vis = &input.vis;
65
66    // Validate parameters: must be raw pointers or scalars
67    for param in params.iter() {
68        if let FnArg::Typed(pat_type) = param {
69            if let Err(err) = validate_kernel_param(&pat_type.ty, &pat_type.pat) {
70                return err;
71            }
72        }
73    }
74
75    // Generate the kernel function for nvptx64
76    let expanded = quote! {
77        #[no_mangle]
78        #vis unsafe extern "ptx-kernel" fn #name(#params) #body
79    };
80
81    TokenStream::from(expanded)
82}
83
84/// Validate that a kernel parameter type is GPU-compatible.
85///
86/// Returns `Ok(())` if valid, `Err(TokenStream)` with a `compile_error!` if not.
87fn validate_kernel_param(ty: &syn::Type, pat: &Pat) -> Result<(), TokenStream> {
88    match ty {
89        // Raw pointers are always OK
90        syn::Type::Ptr(_) => Ok(()),
91        // Path types: check if they're known scalars
92        syn::Type::Path(tp) => {
93            if let Some(seg) = tp.path.segments.last() {
94                let name = seg.ident.to_string();
95                let valid_scalars = [
96                    "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize", "f32",
97                    "f64", "bool",
98                ];
99                if !valid_scalars.contains(&name.as_str()) {
100                    let msg = format!(
101                        "warp_kernel: parameter `{}` has type `{}` which is not a GPU-compatible type. \
102                         Use raw pointers (*const T, *mut T) for device memory or scalar types (u32, i32, f32, etc.).",
103                        quote!(#pat), name
104                    );
105                    return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
106                }
107            }
108            Ok(())
109        }
110        _ => {
111            let msg = format!(
112                "warp_kernel: parameter `{}` has unsupported type `{}`. \
113                 Kernel parameters must be raw pointers or scalar types.",
114                quote!(#pat),
115                quote!(#ty)
116            );
117            Err(syn::Error::new_spanned(ty, msg).to_compile_error().into())
118        }
119    }
120}