Skip to main content

warp_types_builder/
lib.rs

1//! Build-time PTX compilation for warp-types GPU kernels.
2//!
3//! Use in your `build.rs` to cross-compile a kernel crate to PTX,
4//! then load the generated PTX at runtime via cudarc.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! // build.rs
10//! warp_types_builder::WarpBuilder::new("my-kernels")
11//!     .build()
12//!     .expect("Failed to compile GPU kernels");
13//! ```
14//!
15//! Then in your main crate:
16//!
17//! ```rust,ignore
18//! // src/main.rs
19//! mod kernels {
20//!     include!(concat!(env!("OUT_DIR"), "/kernels.rs"));
21//! }
22//!
23//! fn main() {
24//!     let ctx = cudarc::driver::CudaContext::new(0).unwrap();
25//!     let k = kernels::Kernels::load(&ctx).unwrap();
26//!     // k.butterfly_reduce — CudaFunction handle ready for launch
27//! }
28//! ```
29
30use std::env;
31use std::path::{Path, PathBuf};
32use std::process::Command;
33
/// GPU target for cross-compilation.
///
/// Defaults to [`GpuTarget::Nvidia`], the same target `WarpBuilder::new` selects.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GpuTarget {
    /// NVIDIA nvptx64 (32-lane warps, PTX output)
    #[default]
    Nvidia,
    /// AMD amdgcn (64-lane wavefronts, AMDGPU output)
    Amd,
}
42
43impl GpuTarget {
44    fn triple(&self) -> &str {
45        match self {
46            GpuTarget::Nvidia => "nvptx64-nvidia-cuda",
47            GpuTarget::Amd => "amdgcn-amd-amdhsa",
48        }
49    }
50
51    #[allow(dead_code)]
52    fn asm_extension(&self) -> &str {
53        match self {
54            GpuTarget::Nvidia => "s", // PTX assembly
55            GpuTarget::Amd => "s",    // GCN assembly
56        }
57    }
58}
59
60/// Builder for cross-compiling kernel crates to GPU assembly.
61pub struct WarpBuilder {
62    /// Path to the kernel crate (relative to the manifest dir).
63    kernel_crate: PathBuf,
64    /// Rust toolchain to use (default: "nightly").
65    toolchain: String,
66    /// Release mode (default: true).
67    release: bool,
68    /// Feature flags to pass to the kernel crate.
69    features: Vec<String>,
70    /// GPU target (default: NVIDIA).
71    target: GpuTarget,
72}
73
74impl WarpBuilder {
75    /// Create a new builder pointing at a kernel crate directory.
76    ///
77    /// The path is relative to the directory containing the host crate's `Cargo.toml`
78    /// (i.e., `CARGO_MANIFEST_DIR`).
79    pub fn new(kernel_crate_path: impl Into<PathBuf>) -> Self {
80        WarpBuilder {
81            kernel_crate: kernel_crate_path.into(),
82            toolchain: "nightly".to_string(),
83            release: true,
84            features: Vec::new(),
85            target: GpuTarget::Nvidia,
86        }
87    }
88
89    /// Set the Rust toolchain (default: "nightly").
90    pub fn toolchain(mut self, toolchain: impl Into<String>) -> Self {
91        self.toolchain = toolchain.into();
92        self
93    }
94
95    /// Disable release mode (compile in debug mode).
96    pub fn debug(mut self) -> Self {
97        self.release = false;
98        self
99    }
100
101    /// Set the GPU target (default: NVIDIA).
102    pub fn target(mut self, target: GpuTarget) -> Self {
103        self.target = target;
104        self
105    }
106
107    /// Enable a feature flag on the kernel crate.
108    pub fn feature(mut self, feature: impl Into<String>) -> Self {
109        self.features.push(feature.into());
110        self
111    }
112
113    /// Build the kernel crate, producing PTX and generating a Rust module.
114    ///
115    /// On success, writes to `OUT_DIR`:
116    /// - `kernels.ptx` — the raw PTX assembly
117    /// - `kernels.rs` — a Rust module with `KERNEL_PTX` constant and a `Kernels` struct
118    ///   that provides named `CudaFunction` handles for each kernel entry point
119    ///
120    /// Also prints `cargo:rerun-if-changed` for all kernel source files (recursive).
121    pub fn build(self) -> Result<BuildResult, BuildError> {
122        let manifest_dir =
123            env::var("CARGO_MANIFEST_DIR").map_err(|_| BuildError::NotInBuildScript)?;
124        let out_dir = env::var("OUT_DIR").map_err(|_| BuildError::NotInBuildScript)?;
125
126        let kernel_dir = Path::new(&manifest_dir).join(&self.kernel_crate);
127        if !kernel_dir.exists() {
128            return Err(BuildError::KernelCrateNotFound(kernel_dir));
129        }
130
131        // Tell cargo to rerun if ANY kernel source file changes (recursive)
132        emit_rerun_if_changed(&kernel_dir);
133
134        // Invoke cargo rustc for nvptx64 with --emit=asm to get PTX output.
135        // Use RUSTUP_TOOLCHAIN env var instead of +nightly syntax, because
136        // the +toolchain syntax requires rustup's proxy and doesn't work
137        // when cargo is invoked from within a build script.
138        let mut cmd = Command::new("cargo");
139        cmd.arg("rustc")
140            .arg("--target")
141            .arg(self.target.triple())
142            .arg("-Z")
143            .arg("build-std=core")
144            .current_dir(&kernel_dir);
145
146        if self.release {
147            cmd.arg("--release");
148        }
149
150        // Pass feature flags to the kernel crate
151        for feat in &self.features {
152            cmd.arg("--features").arg(feat);
153        }
154
155        // After `--`, pass rustc flags: emit assembly (PTX for nvptx64)
156        cmd.arg("--").arg("--emit=asm");
157
158        // Select nightly toolchain via env var (works inside build scripts).
159        // CRITICAL: remove RUSTC — the parent cargo sets it to the absolute path
160        // of its own rustc (e.g., stable's rustc), which the inner cargo would
161        // inherit and use directly, bypassing toolchain selection entirely.
162        // This is the same fix that spirv-builder uses.
163        cmd.env("RUSTUP_TOOLCHAIN", &self.toolchain);
164        cmd.env_remove("RUSTC");
165        cmd.env("RUSTFLAGS", "--cfg warp_kernel_build");
166
167        let output = cmd
168            .output()
169            .map_err(|e| BuildError::CargoFailed(format!("Failed to run cargo: {}", e)))?;
170
171        if !output.status.success() {
172            let stderr = String::from_utf8_lossy(&output.stderr);
173            return Err(BuildError::CompilationFailed(stderr.into_owned()));
174        }
175
176        // Find the generated PTX/assembly file
177        let profile = if self.release { "release" } else { "debug" };
178        let target_dir = kernel_dir
179            .join("target")
180            .join(self.target.triple())
181            .join(profile);
182
183        let ptx_path = find_ptx_file(&target_dir, &kernel_dir)?;
184
185        // Read PTX content
186        let ptx_content = std::fs::read_to_string(&ptx_path)
187            .map_err(|e| BuildError::PtxReadFailed(format!("{}: {}", ptx_path.display(), e)))?;
188
189        // Parse kernel entry points from PTX
190        let kernels = parse_kernel_entries(&ptx_content);
191
192        // Write PTX to OUT_DIR
193        let out_ptx = Path::new(&out_dir).join("kernels.ptx");
194        std::fs::write(&out_ptx, &ptx_content)
195            .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_ptx.display(), e)))?;
196
197        // Generate Rust module with Kernels struct
198        let out_rs = Path::new(&out_dir).join("kernels.rs");
199        let rs_content = generate_rust_module(&self.kernel_crate, &kernels);
200        std::fs::write(&out_rs, &rs_content)
201            .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_rs.display(), e)))?;
202
203        Ok(BuildResult {
204            ptx_path: out_ptx,
205            module_path: out_rs,
206            kernel_names: kernels,
207        })
208    }
209}
210
/// Result of a successful build.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BuildResult {
    /// Path to the generated PTX file in OUT_DIR.
    pub ptx_path: PathBuf,
    /// Path to the generated Rust module in OUT_DIR.
    pub module_path: PathBuf,
    /// Names of kernel entry points found in the PTX, in order of appearance.
    pub kernel_names: Vec<String>,
}
220
/// Errors that can occur during kernel compilation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BuildError {
    /// Not running inside a build script (`CARGO_MANIFEST_DIR` or `OUT_DIR` not set).
    NotInBuildScript,
    /// Kernel crate directory not found (the manifest dir joined with the configured path).
    KernelCrateNotFound(PathBuf),
    /// The `cargo` process could not be spawned.
    CargoFailed(String),
    /// Kernel crate compilation failed (holds cargo's stderr).
    CompilationFailed(String),
    /// Could not find the generated PTX file (or could not read the kernel `Cargo.toml`).
    PtxNotFound(String),
    /// Could not read PTX file.
    PtxReadFailed(String),
    /// Could not write output files.
    WriteFailed(String),
}
239
240impl std::fmt::Display for BuildError {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        match self {
243            BuildError::NotInBuildScript => write!(
244                f,
245                "Not running in a build script (CARGO_MANIFEST_DIR not set)"
246            ),
247            BuildError::KernelCrateNotFound(p) => {
248                write!(f, "Kernel crate not found: {}", p.display())
249            }
250            BuildError::CargoFailed(s) => write!(f, "Cargo invocation failed: {}", s),
251            BuildError::CompilationFailed(s) => write!(f, "Kernel compilation failed:\n{}", s),
252            BuildError::PtxNotFound(s) => write!(f, "PTX file not found: {}", s),
253            BuildError::PtxReadFailed(s) => write!(f, "Failed to read PTX: {}", s),
254            BuildError::WriteFailed(s) => write!(f, "Failed to write output: {}", s),
255        }
256    }
257}
258
259impl std::error::Error for BuildError {}
260
261// ============================================================================
262// PTX parsing
263// ============================================================================
264
/// Collect kernel entry-point names from PTX text.
///
/// A kernel is declared on a line of the form `.visible .entry <name>(`. Surrounding
/// whitespace on the line is ignored. The name is the text between the directive and
/// the first `(` (or the end of the line), trimmed. Empty names are skipped, and names
/// are returned in the order they appear.
fn parse_kernel_entries(ptx: &str) -> Vec<String> {
    const ENTRY_DIRECTIVE: &str = ".visible .entry ";

    let mut names = Vec::new();
    for line in ptx.lines() {
        let Some(after_directive) = line.trim().strip_prefix(ENTRY_DIRECTIVE) else {
            continue;
        };
        // Everything before the first `(`, or the whole remainder if there is none.
        let candidate = after_directive
            .split_once('(')
            .map_or(after_directive, |(head, _)| head)
            .trim();
        if candidate.is_empty() {
            continue;
        }
        names.push(candidate.to_owned());
    }
    names
}
285
286// ============================================================================
287// Code generation
288// ============================================================================
289
/// Generate the Rust source for `kernels.rs`.
///
/// The emitted code defines:
/// - `KERNEL_PTX`: the PTX text, embedded with `include_str!` from `OUT_DIR/kernels.ptx`;
/// - `Kernels`: one public `CudaFunction` field per name in `kernels`, plus a private
///   `_module` field holding the loaded `CudaModule`;
/// - `Kernels::load`: loads the PTX as a CUDA module and resolves every entry point.
///
/// Kernel names that are Rust keywords are emitted as raw identifiers (e.g. `r#type`)
/// so the generated code still compiles; the PTX lookup always uses the original name.
fn generate_rust_module(kernel_crate: &Path, kernels: &[String]) -> String {
    let mut code = String::new();

    // Header
    code.push_str(&format!(
        "// Auto-generated by warp-types-builder. Do not edit.\n\
         // Kernel crate: {}\n\
         // Kernel count: {}\n\n",
        kernel_crate.display(),
        kernels.len(),
    ));

    // PTX constant
    code.push_str(
        "/// Raw PTX assembly for all kernels in this module.\n\
         pub const KERNEL_PTX: &str = include_str!(concat!(env!(\"OUT_DIR\"), \"/kernels.ptx\"));\n\n",
    );

    // Kernels struct. The `CudaFunction` link carries an explicit path because the
    // generated module does not import that type.
    code.push_str(
        "/// Loaded GPU kernel functions.\n\
         ///\n\
         /// Created by [`Kernels::load`], which parses the PTX and extracts\n\
         /// each kernel entry point as a ready-to-launch\n\
         /// [`CudaFunction`](cudarc::driver::CudaFunction).\n\
         ///\n\
         /// # Available kernels\n",
    );
    for name in kernels {
        code.push_str(&format!("/// - `{name}`\n"));
    }
    code.push_str(
        "pub struct Kernels {\n    \
         _module: ::std::sync::Arc<::cudarc::driver::CudaModule>,\n",
    );
    for name in kernels {
        let ident = rust_ident(name);
        code.push_str(&format!(
            "    /// Kernel: `{name}`\n    pub {ident}: ::cudarc::driver::CudaFunction,\n"
        ));
    }
    code.push_str("}\n\n");

    // Kernels::load impl. Fields are initialised straight from `load_function` calls
    // instead of through `let` bindings, so a kernel named e.g. `module` cannot shadow
    // the locals used here. `_module` comes last so the module is moved only after
    // every function has been resolved.
    code.push_str(
        "impl Kernels {\n    \
         /// Load all kernels from the compiled PTX.\n    \
         ///\n    \
         /// Parses the embedded PTX assembly, loads it as a CUDA module,\n    \
         /// and extracts each kernel entry point by name.\n    \
         pub fn load(\n        \
         ctx: &::std::sync::Arc<::cudarc::driver::CudaContext>,\n    \
         ) -> ::std::result::Result<Self, Box<dyn ::std::error::Error>> {\n        \
         let ptx = ::cudarc::nvrtc::Ptx::from_src(KERNEL_PTX.to_string());\n        \
         let module = ctx.load_module(ptx)?;\n        \
         Ok(Kernels {\n",
    );
    for name in kernels {
        let ident = rust_ident(name);
        // `{name:?}` renders a correctly escaped Rust string literal.
        code.push_str(&format!(
            "            {ident}: module.load_function({name:?})?,\n"
        ));
    }
    code.push_str("            _module: module,\n        })\n    }\n}\n");

    code
}

/// Return `name` as a Rust identifier token, using raw-identifier syntax for keywords.
///
/// `self`, `Self`, `super`, `crate` and `_` cannot be raw identifiers and are returned
/// unchanged.
fn rust_ident(name: &str) -> String {
    const KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
        "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    if KEYWORDS.contains(&name) {
        format!("r#{name}")
    } else {
        name.to_string()
    }
}
361
362// ============================================================================
363// File watching
364// ============================================================================
365
366/// Emit `cargo:rerun-if-changed` for all files in the kernel crate recursively.
367fn emit_rerun_if_changed(kernel_dir: &Path) {
368    println!(
369        "cargo:rerun-if-changed={}",
370        kernel_dir.join("Cargo.toml").display()
371    );
372
373    let src_dir = kernel_dir.join("src");
374    if src_dir.exists() {
375        emit_rerun_recursive(&src_dir);
376    }
377}
378
/// Recursively print `cargo:rerun-if-changed=<path>` for every non-directory entry
/// under `dir`. Directories or entries that cannot be read are skipped silently.
fn emit_rerun_recursive(dir: &Path) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };
    for child in listing.filter_map(Result::ok).map(|entry| entry.path()) {
        if child.is_dir() {
            emit_rerun_recursive(&child);
            continue;
        }
        println!("cargo:rerun-if-changed={}", child.display());
    }
}
391
392// ============================================================================
393// PTX file discovery
394// ============================================================================
395
396/// Search for the PTX (.s) file in the target directory.
397fn find_ptx_file(target_dir: &Path, kernel_dir: &Path) -> Result<PathBuf, BuildError> {
398    // Get the crate name from Cargo.toml
399    let cargo_toml = kernel_dir.join("Cargo.toml");
400    let content = std::fs::read_to_string(&cargo_toml).map_err(|e| {
401        BuildError::PtxNotFound(format!("Can't read {}: {}", cargo_toml.display(), e))
402    })?;
403
404    // Simple TOML parsing — find name = "..."
405    let crate_name = content
406        .lines()
407        .find_map(|line| {
408            let line = line.trim();
409            if line.starts_with("name") {
410                let val = line.split('=').nth(1)?.trim().trim_matches('"');
411                Some(val.replace('-', "_"))
412            } else {
413                None
414            }
415        })
416        .unwrap_or_else(|| "kernels".to_string());
417
418    // Check common locations for the .s file
419    let candidates = [
420        target_dir.join(format!("{}.s", crate_name)),
421        target_dir.join(format!("lib{}.s", crate_name)),
422        target_dir.join(format!("{}.ptx", crate_name)),
423        target_dir.join("deps").join(format!("{}.s", crate_name)),
424        target_dir.join("deps").join(format!("lib{}.s", crate_name)),
425    ];
426
427    for path in &candidates {
428        if path.exists() {
429            return Ok(path.clone());
430        }
431    }
432
433    // Fallback: search deps/ for .s files matching the crate name pattern
434    let deps = target_dir.join("deps");
435    if let Ok(entries) = std::fs::read_dir(&deps) {
436        for entry in entries.flatten() {
437            let p = entry.path();
438            if p.extension().is_some_and(|e| e == "s") {
439                let fname = p
440                    .file_stem()
441                    .map(|s| s.to_string_lossy().to_string())
442                    .unwrap_or_default();
443                // Match crate_name-HASH.s pattern, skip core/compiler_builtins
444                if fname.starts_with(&crate_name)
445                    && !fname.starts_with("core-")
446                    && !fname.starts_with("compiler_builtins-")
447                {
448                    return Ok(p);
449                }
450            }
451        }
452    }
453
454    // Last resort: any non-core .s file in deps/
455    if let Ok(entries) = std::fs::read_dir(&deps) {
456        for entry in entries.flatten() {
457            let p = entry.path();
458            if p.extension().is_some_and(|e| e == "s") {
459                let fname = p
460                    .file_stem()
461                    .map(|s| s.to_string_lossy().to_string())
462                    .unwrap_or_default();
463                if !fname.starts_with("core-")
464                    && !fname.starts_with("compiler_builtins-")
465                    && !fname.starts_with("warp_types-")
466                {
467                    return Ok(p);
468                }
469            }
470        }
471    }
472
473    Err(BuildError::PtxNotFound(format!(
474        "No .s/.ptx file found in {}. Crate name: '{}'. Checked: {:?}",
475        target_dir.display(),
476        crate_name,
477        candidates
478            .iter()
479            .map(|c| c.display().to_string())
480            .collect::<Vec<_>>()
481    )))
482}