Skip to main content

rustkernel_derive/
lib.rs

//! Procedural macros for RustKernels.
//!
//! This crate provides the following macros:
//! - `#[gpu_kernel]` - Define a GPU kernel with metadata
//! - `#[derive(KernelMessage)]` - Derive a message type ID and `BatchMessage` impl for kernel messages
//! - `#[kernel_state]` - Apply `#[repr(C)]` and common derives to kernel state types
//!
//! For low-level ring kernel macros, see `ringkernel-derive` 0.4.2, which provides:
//! - `#[derive(RingMessage)]` - Ring message serialization with domain-based type IDs
//! - `#[derive(PersistentMessage)]` - CUDA persistent message dispatch
//! - `#[derive(ControlBlockState)]` - Embedded state for GPU ControlBlocks
//! - `#[derive(GpuType)]` - Pod+Zeroable for GPU data transfer
//! - `#[ring_kernel]` - Ring kernel handler generation
//! - `#[stencil_kernel]` - CUDA stencil pattern kernels
//!
//! # Example
//!
//! ```ignore
//! use rustkernel_derive::gpu_kernel;
//!
//! #[gpu_kernel(
//!     id = "graph/pagerank",
//!     mode = "ring",
//!     domain = "GraphAnalytics",
//!     throughput = 100_000,
//!     latency_us = 1.0
//! )]
//! pub async fn pagerank_kernel(
//!     ctx: &mut RingContext,
//!     request: PageRankRequest,
//! ) -> PageRankResponse {
//!     // Implementation
//! }
//! ```
35
36use darling::{FromDeriveInput, FromMeta};
37use proc_macro::TokenStream;
38use quote::quote;
39use syn::{DeriveInput, ItemFn, parse_macro_input};
40
41/// Arguments for the `#[gpu_kernel]` attribute.
42#[derive(Debug, FromMeta)]
43struct GpuKernelArgs {
44    /// Kernel ID (e.g., "graph/pagerank").
45    id: String,
46
47    /// Kernel mode: "batch" or "ring".
48    mode: String,
49
50    /// Domain name (e.g., "GraphAnalytics").
51    domain: String,
52
53    /// Description (optional).
54    #[darling(default)]
55    description: Option<String>,
56
57    /// Expected throughput in ops/sec (optional).
58    #[darling(default)]
59    throughput: Option<u64>,
60
61    /// Target latency in microseconds (optional).
62    #[darling(default)]
63    latency_us: Option<f64>,
64
65    /// Whether GPU-native execution is required (optional).
66    #[darling(default)]
67    gpu_native: Option<bool>,
68}
69
70/// Define a GPU kernel with metadata.
71///
72/// This attribute generates a kernel struct and implements the necessary traits.
73///
74/// # Attributes
75///
76/// - `id` - Unique kernel identifier (required)
77/// - `mode` - Kernel mode: "batch" or "ring" (required)
78/// - `domain` - Business domain (required)
79/// - `description` - Human-readable description (optional)
80/// - `throughput` - Expected throughput in ops/sec (optional)
81/// - `latency_us` - Target latency in microseconds (optional)
82/// - `gpu_native` - Whether GPU-native execution is required (optional)
83///
84/// # Example
85///
86/// ```ignore
87/// #[gpu_kernel(
88///     id = "graph/pagerank",
89///     mode = "ring",
90///     domain = "GraphAnalytics",
91///     description = "PageRank centrality calculation",
92///     throughput = 100_000,
93///     latency_us = 1.0,
94///     gpu_native = true
95/// )]
96/// pub async fn pagerank(ctx: &mut RingContext, req: PageRankRequest) -> PageRankResponse {
97///     // Implementation
98/// }
99/// ```
100#[proc_macro_attribute]
101pub fn gpu_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
102    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
103        Ok(v) => v,
104        Err(e) => return TokenStream::from(e.to_compile_error()),
105    };
106
107    let args = match GpuKernelArgs::from_list(&args) {
108        Ok(v) => v,
109        Err(e) => return TokenStream::from(e.write_errors()),
110    };
111
112    let input = parse_macro_input!(item as ItemFn);
113    let fn_name = &input.sig.ident;
114    let fn_vis = &input.vis;
115    let fn_block = &input.block;
116    let fn_inputs = &input.sig.inputs;
117    let fn_output = &input.sig.output;
118    let fn_asyncness = &input.sig.asyncness;
119
120    // Generate struct name from function name (PascalCase)
121    let struct_name = to_pascal_case(&fn_name.to_string());
122    let struct_ident = syn::Ident::new(&struct_name, fn_name.span());
123
124    // Parse mode
125    let mode = match args.mode.as_str() {
126        "batch" => quote! { rustkernel_core::kernel::KernelMode::Batch },
127        "ring" => quote! { rustkernel_core::kernel::KernelMode::Ring },
128        _ => {
129            return syn::Error::new_spanned(&input.sig, "mode must be 'batch' or 'ring'")
130                .to_compile_error()
131                .into();
132        }
133    };
134
135    // Parse domain
136    let domain = &args.domain;
137    let domain_ident = syn::Ident::new(domain, proc_macro2::Span::call_site());
138
139    // Default values
140    let description = args.description.unwrap_or_default();
141    let throughput = args.throughput.unwrap_or(10_000);
142    let latency_us = args.latency_us.unwrap_or(50.0);
143    let gpu_native = args.gpu_native.unwrap_or(false);
144    let kernel_id = &args.id;
145
146    // Generate the kernel struct and implementation
147    let expanded = quote! {
148        /// Generated kernel struct for #fn_name.
149        #[derive(Debug, Clone)]
150        #fn_vis struct #struct_ident {
151            metadata: rustkernel_core::kernel::KernelMetadata,
152        }
153
154        impl #struct_ident {
155            /// Create a new instance of this kernel.
156            #[must_use]
157            pub fn new() -> Self {
158                Self {
159                    metadata: rustkernel_core::kernel::KernelMetadata {
160                        id: #kernel_id.to_string(),
161                        mode: #mode,
162                        domain: rustkernel_core::domain::Domain::#domain_ident,
163                        description: #description.to_string(),
164                        expected_throughput: #throughput,
165                        target_latency_us: #latency_us,
166                        requires_gpu_native: #gpu_native,
167                        version: 1,
168                    },
169                }
170            }
171        }
172
173        impl Default for #struct_ident {
174            fn default() -> Self {
175                Self::new()
176            }
177        }
178
179        impl rustkernel_core::traits::GpuKernel for #struct_ident {
180            fn metadata(&self) -> &rustkernel_core::kernel::KernelMetadata {
181                &self.metadata
182            }
183        }
184
185        // Keep the original function for implementation
186        #fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output
187        #fn_block
188    };
189
190    TokenStream::from(expanded)
191}
192
/// Turn a `snake_case` identifier into `PascalCase`.
///
/// Underscores only separate words. Each non-empty word has its first character
/// upper-cased, which may expand to several characters (`ß` -> `SS`), and the rest of
/// the word is copied verbatim. Leading, trailing and repeated underscores therefore
/// disappear from the result.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
206
207/// Arguments for `#[derive(KernelMessage)]`.
208#[derive(Debug, FromDeriveInput)]
209#[darling(attributes(message))]
210struct KernelMessageArgs {
211    ident: syn::Ident,
212    generics: syn::Generics,
213
214    /// Message type ID.
215    #[darling(default)]
216    type_id: Option<u64>,
217
218    /// Domain for the message (reserved for future use).
219    #[darling(default)]
220    #[allow(dead_code)]
221    domain: Option<String>,
222}
223
224/// Derive macro for kernel messages.
225///
226/// This generates implementations for the `BatchMessage` trait, providing
227/// serialization and type information for batch kernel messages.
228///
229/// # Attributes
230///
231/// - `type_id` - Unique message type identifier (optional, defaults to hash of type name)
232/// - `domain` - Domain for the message (optional)
233///
234/// # Example
235///
236/// ```ignore
237/// #[derive(Debug, Clone, Serialize, Deserialize, KernelMessage)]
238/// #[message(type_id = 100, domain = "GraphAnalytics")]
239/// pub struct PageRankInput {
240///     pub graph: CsrGraph,
241///     pub damping: f64,
242/// }
243/// ```
244///
245/// # Generated Implementation
246///
247/// The macro generates:
248/// - `BatchMessage` trait implementation with `message_type_id()`
249/// - `to_json()` and `from_json()` methods for JSON serialization
250/// - A `message_type_id()` associated function on the type itself
251#[proc_macro_derive(KernelMessage, attributes(message))]
252pub fn derive_kernel_message(input: TokenStream) -> TokenStream {
253    let input = parse_macro_input!(input as DeriveInput);
254
255    let args = match KernelMessageArgs::from_derive_input(&input) {
256        Ok(v) => v,
257        Err(e) => return TokenStream::from(e.write_errors()),
258    };
259
260    let name = args.ident;
261    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
262
263    // Calculate type_id: use provided value or hash of type name
264    let type_id = args.type_id.unwrap_or_else(|| {
265        use std::collections::hash_map::DefaultHasher;
266        use std::hash::{Hash, Hasher};
267        let mut hasher = DefaultHasher::new();
268        name.to_string().hash(&mut hasher);
269        hasher.finish()
270    });
271
272    let expanded = quote! {
273        // Associated function for direct access
274        impl #impl_generics #name #ty_generics #where_clause {
275            /// Get the message type ID.
276            #[must_use]
277            pub const fn message_type_id() -> u64 {
278                #type_id
279            }
280        }
281
282        // Implement BatchMessage trait for batch kernel communication
283        impl #impl_generics ::rustkernel_core::messages::BatchMessage for #name #ty_generics #where_clause {
284            fn message_type_id() -> u64 {
285                #type_id
286            }
287        }
288    };
289
290    TokenStream::from(expanded)
291}
292
293/// Attribute for marking kernel state types.
294///
295/// This ensures the type meets GPU requirements (unmanaged, fixed layout).
296///
297/// # Example
298///
299/// ```ignore
300/// #[kernel_state(size = 256)]
301/// pub struct PageRankState {
302///     pub scores: [f32; 64],
303/// }
304/// ```
305#[proc_macro_attribute]
306pub fn kernel_state(_attr: TokenStream, item: TokenStream) -> TokenStream {
307    // For now, just pass through - state validation can be added later
308    let input = parse_macro_input!(item as DeriveInput);
309
310    let expanded = quote! {
311        #[repr(C)]
312        #[derive(Clone, Copy, Debug, Default)]
313        #input
314    };
315
316    TokenStream::from(expanded)
317}