warp_types_kernel/lib.rs
1//! Proc macro for marking GPU kernel functions.
2//!
//! `#[warp_kernel]` rewrites a function into a PTX kernel entry point
//! (`#[no_mangle]`, `unsafe`, `extern "ptx-kernel"`). It is intended for crates
//! compiled for nvptx64; it does not generate a host-side launcher.
6//!
7//! # Usage
8//!
9//! In your kernel crate (compiled for nvptx64):
10//!
11//! ```rust,ignore
12//! use warp_types::prelude::*;
13//! use warp_types_kernel::warp_kernel;
14//!
15//! #[warp_kernel]
16//! pub fn butterfly_reduce(data: *mut i32) {
17//! let warp: Warp<All> = Warp::kernel_entry();
18//! let tid = warp_types::gpu::thread_id_x();
19//! let mut val = unsafe { *data.add(tid as usize) };
20//!
21//! val += warp.shuffle_xor(PerLane::new(val), 16).get();
22//! val += warp.shuffle_xor(PerLane::new(val), 8).get();
23//! val += warp.shuffle_xor(PerLane::new(val), 4).get();
24//! val += warp.shuffle_xor(PerLane::new(val), 2).get();
25//! val += warp.shuffle_xor(PerLane::new(val), 1).get();
26//!
27//! unsafe { *data.add(tid as usize) = val; }
28//! }
29//! ```
30//!
//! The macro emits:
//! - `#[no_mangle] pub unsafe extern "ptx-kernel" fn butterfly_reduce(...)`
//! - The expansion is not gated on the target, so kernel crates should only be compiled for nvptx64
34
35use proc_macro::TokenStream;
36use quote::quote;
37use syn::{parse_macro_input, FnArg, ItemFn, Pat};
38
39/// Mark a function as a GPU kernel entry point.
40///
41/// This attribute transforms the function signature for PTX compilation:
42/// - Adds `#[no_mangle]` for symbol visibility in PTX
43/// - Adds `extern "ptx-kernel"` ABI
44/// - Wraps the body in `unsafe` (PTX kernels are inherently unsafe)
45///
46/// # Parameter Rules
47///
48/// Kernel parameters must be one of:
49/// - Raw pointers (`*const T`, `*mut T`) — for device memory
50/// - Scalars (`u32`, `i32`, `f32`, `u64`, `i64`, `f64`, `bool`) — passed by value
51///
52/// # Compile-Time Safety
53///
54/// The function body uses warp-types normally. `Warp::kernel_entry()` creates
55/// the initial `Warp<All>`, and the type system prevents shuffle-from-inactive-lane
56/// bugs at compile time — on the actual GPU target.
57#[proc_macro_attribute]
58pub fn warp_kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
59 let input = parse_macro_input!(item as ItemFn);
60
61 let name = &input.sig.ident;
62 let params = &input.sig.inputs;
63 let body = &input.block;
64 let vis = &input.vis;
65
66 // Validate parameters: must be raw pointers or scalars
67 for param in params.iter() {
68 if let FnArg::Typed(pat_type) = param {
69 if let Err(err) = validate_kernel_param(&pat_type.ty, &pat_type.pat) {
70 return err;
71 }
72 }
73 }
74
75 // Generate the kernel function for nvptx64
76 let expanded = quote! {
77 #[no_mangle]
78 #vis unsafe extern "ptx-kernel" fn #name(#params) #body
79 };
80
81 TokenStream::from(expanded)
82}
83
84/// Validate that a kernel parameter type is GPU-compatible.
85///
86/// Returns `Ok(())` if valid, `Err(TokenStream)` with a `compile_error!` if not.
87fn validate_kernel_param(ty: &syn::Type, pat: &Pat) -> Result<(), TokenStream> {
88 match ty {
89 // Raw pointers are always OK
90 syn::Type::Ptr(_) => Ok(()),
91 // Path types: check if they're known scalars
92 syn::Type::Path(tp) => {
93 if let Some(seg) = tp.path.segments.last() {
94 let name = seg.ident.to_string();
95 let valid_scalars = [
96 "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize", "f32",
97 "f64", "bool",
98 ];
99 if !valid_scalars.contains(&name.as_str()) {
100 let msg = format!(
101 "warp_kernel: parameter `{}` has type `{}` which is not a GPU-compatible type. \
102 Use raw pointers (*const T, *mut T) for device memory or scalar types (u32, i32, f32, etc.).",
103 quote!(#pat), name
104 );
105 return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
106 }
107 }
108 Ok(())
109 }
110 _ => {
111 let msg = format!(
112 "warp_kernel: parameter `{}` has unsupported type `{}`. \
113 Kernel parameters must be raw pointers or scalar types.",
114 quote!(#pat),
115 quote!(#ty)
116 );
117 Err(syn::Error::new_spanned(ty, msg).to_compile_error().into())
118 }
119 }
120}