Skip to main content

warp_types_builder/
lib.rs

1//! Build-time PTX compilation for warp-types GPU kernels.
2//!
3//! Use in your `build.rs` to cross-compile a kernel crate to PTX,
4//! then load the generated PTX at runtime via cudarc.
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! // build.rs
10//! warp_types_builder::WarpBuilder::new("my-kernels")
11//!     .build()
12//!     .expect("Failed to compile GPU kernels");
13//! ```
14//!
15//! Then in your main crate:
16//!
17//! ```rust,ignore
18//! // src/main.rs
19//! mod kernels {
20//!     include!(concat!(env!("OUT_DIR"), "/kernels.rs"));
21//! }
22//!
23//! fn main() {
24//!     let ctx = cudarc::driver::CudaContext::new(0).unwrap();
25//!     let k = kernels::Kernels::load(&ctx).unwrap();
26//!     // k.butterfly_reduce — CudaFunction handle ready for launch
27//! }
28//! ```
29
30use std::env;
31use std::path::{Path, PathBuf};
32use std::process::Command;
33
/// GPU target for cross-compilation.
///
/// The enum is a plain tag, so it is `Copy` and comparable; callers can
/// branch on it with `==` as well as `match`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GpuTarget {
    /// NVIDIA nvptx64 (32-lane warps, PTX output)
    Nvidia,
    /// AMD amdgcn (64-lane wavefronts, AMDGPU output)
    Amd,
}

impl GpuTarget {
    /// Target triple passed to `cargo rustc --target`.
    ///
    /// The returned string is a literal, so it is not tied to the lifetime
    /// of `self`.
    fn triple(&self) -> &'static str {
        match self {
            GpuTarget::Nvidia => "nvptx64-nvidia-cuda",
            GpuTarget::Amd => "amdgcn-amd-amdhsa",
        }
    }

    /// File extension of the assembly file associated with this target.
    ///
    /// Both arms currently return `"s"`; the match is kept so each target
    /// documents what that file contains.
    // Nothing in this file calls this helper yet, hence the lint allowance.
    #[allow(dead_code)]
    fn asm_extension(&self) -> &'static str {
        match self {
            GpuTarget::Nvidia => "s", // PTX assembly
            GpuTarget::Amd => "s",    // GCN assembly
        }
    }
}
59
60/// Builder for cross-compiling kernel crates to GPU assembly.
61pub struct WarpBuilder {
62    /// Path to the kernel crate (relative to the manifest dir).
63    kernel_crate: PathBuf,
64    /// Rust toolchain to use (default: "nightly").
65    toolchain: String,
66    /// Release mode (default: true).
67    release: bool,
68    /// Feature flags to pass to the kernel crate.
69    features: Vec<String>,
70    /// GPU target (default: NVIDIA).
71    target: GpuTarget,
72    /// NVIDIA SM architecture (default: "sm_70").
73    ///
74    /// Controls the PTX ISA version emitted by LLVM. Must be at least `sm_70`
75    /// (Volta) because warp-types uses `shfl.sync.*` which requires PTX 6.0+.
76    /// Higher values enable newer instructions and may produce better code.
77    /// Common values: sm_70 (Volta), sm_80 (Ampere), sm_89 (Ada), sm_90 (Hopper).
78    sm_arch: String,
79}
80
81impl WarpBuilder {
82    /// Create a new builder pointing at a kernel crate directory.
83    ///
84    /// The path is relative to the directory containing the host crate's `Cargo.toml`
85    /// (i.e., `CARGO_MANIFEST_DIR`).
86    pub fn new(kernel_crate_path: impl Into<PathBuf>) -> Self {
87        WarpBuilder {
88            kernel_crate: kernel_crate_path.into(),
89            toolchain: "nightly".to_string(),
90            release: true,
91            features: Vec::new(),
92            target: GpuTarget::Nvidia,
93            sm_arch: "sm_70".to_string(),
94        }
95    }
96
97    /// Set the Rust toolchain (default: "nightly").
98    pub fn toolchain(mut self, toolchain: impl Into<String>) -> Self {
99        self.toolchain = toolchain.into();
100        self
101    }
102
103    /// Disable release mode (compile in debug mode).
104    pub fn debug(mut self) -> Self {
105        self.release = false;
106        self
107    }
108
109    /// Set the GPU target (default: NVIDIA).
110    pub fn target(mut self, target: GpuTarget) -> Self {
111        self.target = target;
112        self
113    }
114
115    /// Set the NVIDIA SM architecture (default: "sm_70").
116    ///
117    /// Controls `-C target-cpu` passed to rustc, which determines the PTX ISA
118    /// version. Must be at least `sm_70` for `shfl.sync` support.
119    pub fn sm_arch(mut self, arch: impl Into<String>) -> Self {
120        self.sm_arch = arch.into();
121        self
122    }
123
124    /// Enable a feature flag on the kernel crate.
125    pub fn feature(mut self, feature: impl Into<String>) -> Self {
126        self.features.push(feature.into());
127        self
128    }
129
130    /// Build the kernel crate, producing PTX and generating a Rust module.
131    ///
132    /// On success, writes to `OUT_DIR`:
133    /// - `kernels.ptx` — the raw PTX assembly
134    /// - `kernels.rs` — a Rust module with `KERNEL_PTX` constant and a `Kernels` struct
135    ///   that provides named `CudaFunction` handles for each kernel entry point
136    ///
137    /// Also prints `cargo:rerun-if-changed` for all kernel source files (recursive).
138    pub fn build(self) -> Result<BuildResult, BuildError> {
139        let manifest_dir =
140            env::var("CARGO_MANIFEST_DIR").map_err(|_| BuildError::NotInBuildScript)?;
141        let out_dir = env::var("OUT_DIR").map_err(|_| BuildError::NotInBuildScript)?;
142
143        let kernel_dir = Path::new(&manifest_dir).join(&self.kernel_crate);
144        if !kernel_dir.exists() {
145            return Err(BuildError::KernelCrateNotFound(kernel_dir));
146        }
147
148        // Tell cargo to rerun if ANY kernel source file changes (recursive)
149        emit_rerun_if_changed(&kernel_dir);
150
151        // Invoke cargo rustc for nvptx64 with --emit=asm to get PTX output.
152        // Use RUSTUP_TOOLCHAIN env var instead of +nightly syntax, because
153        // the +toolchain syntax requires rustup's proxy and doesn't work
154        // when cargo is invoked from within a build script.
155        let mut cmd = Command::new("cargo");
156        cmd.arg("rustc")
157            .arg("--target")
158            .arg(self.target.triple())
159            .arg("-Z")
160            .arg("build-std=core")
161            .current_dir(&kernel_dir);
162
163        if self.release {
164            cmd.arg("--release");
165        }
166
167        // Pass feature flags to the kernel crate
168        for feat in &self.features {
169            cmd.arg("--features").arg(feat);
170        }
171
172        // After `--`, pass rustc flags: emit assembly (PTX for nvptx64)
173        // -C target-cpu sets the SM architecture, which determines the PTX ISA
174        // version. Without this, LLVM defaults to sm_30 / PTX 3.2, which lacks
175        // shfl.sync (requires PTX 6.0 / sm_70+).
176        cmd.arg("--")
177            .arg("--emit=asm")
178            .arg("-C")
179            .arg(format!("target-cpu={}", self.sm_arch));
180
181        // Select nightly toolchain via env var (works inside build scripts).
182        // CRITICAL: remove RUSTC — the parent cargo sets it to the absolute path
183        // of its own rustc (e.g., stable's rustc), which the inner cargo would
184        // inherit and use directly, bypassing toolchain selection entirely.
185        // This is the same fix that spirv-builder uses.
186        cmd.env("RUSTUP_TOOLCHAIN", &self.toolchain);
187        cmd.env_remove("RUSTC");
188        cmd.env("RUSTFLAGS", "--cfg warp_kernel_build");
189
190        let output = cmd
191            .output()
192            .map_err(|e| BuildError::CargoFailed(format!("Failed to run cargo: {}", e)))?;
193
194        if !output.status.success() {
195            let stderr = String::from_utf8_lossy(&output.stderr);
196            return Err(BuildError::CompilationFailed(stderr.into_owned()));
197        }
198
199        // Find the generated PTX/assembly file
200        let profile = if self.release { "release" } else { "debug" };
201        let target_dir = kernel_dir
202            .join("target")
203            .join(self.target.triple())
204            .join(profile);
205
206        let ptx_path = find_ptx_file(&target_dir, &kernel_dir)?;
207
208        // Read PTX content
209        let ptx_content = std::fs::read_to_string(&ptx_path)
210            .map_err(|e| BuildError::PtxReadFailed(format!("{}: {}", ptx_path.display(), e)))?;
211
212        // Parse kernel entry points from PTX
213        let kernels = parse_kernel_entries(&ptx_content);
214
215        // Write PTX to OUT_DIR
216        let out_ptx = Path::new(&out_dir).join("kernels.ptx");
217        std::fs::write(&out_ptx, &ptx_content)
218            .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_ptx.display(), e)))?;
219
220        // Generate Rust module with Kernels struct
221        let out_rs = Path::new(&out_dir).join("kernels.rs");
222        let rs_content = generate_rust_module(&self.kernel_crate, &kernels);
223        std::fs::write(&out_rs, &rs_content)
224            .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_rs.display(), e)))?;
225
226        Ok(BuildResult {
227            ptx_path: out_ptx,
228            module_path: out_rs,
229            kernel_names: kernels,
230        })
231    }
232}
233
/// Result of a successful build.
///
/// Implements `Debug` so that `Result<BuildResult, _>` supports
/// `unwrap_err()` / `expect_err()` and can be logged from a build script.
#[derive(Clone, Debug)]
pub struct BuildResult {
    /// Path to the generated PTX file in OUT_DIR.
    pub ptx_path: PathBuf,
    /// Path to the generated Rust module in OUT_DIR.
    pub module_path: PathBuf,
    /// Names of kernel entry points found in the PTX.
    pub kernel_names: Vec<String>,
}
243
/// Errors that can occur during kernel compilation.
#[derive(Debug)]
pub enum BuildError {
    /// Not running inside a build script: `CARGO_MANIFEST_DIR` or `OUT_DIR`
    /// is not set in the environment.
    NotInBuildScript,
    /// Kernel crate directory not found; carries the path that was checked.
    KernelCrateNotFound(PathBuf),
    /// The `cargo` process could not be run; carries a description of the failure.
    CargoFailed(String),
    /// Kernel crate compilation failed; carries cargo's captured stderr.
    CompilationFailed(String),
    /// Could not find generated PTX file.
    PtxNotFound(String),
    /// Could not read PTX file.
    PtxReadFailed(String),
    /// Could not write output files.
    WriteFailed(String),
}

impl std::fmt::Display for BuildError {
    /// Render a one-line, human-readable description. `CompilationFailed`
    /// is the exception: it puts the captured compiler output on the lines
    /// following its heading.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Both variables are required by the builder, so name both: the
            // variant is produced when either one is missing.
            BuildError::NotInBuildScript => f.write_str(
                "Not running in a build script (CARGO_MANIFEST_DIR or OUT_DIR not set)",
            ),
            BuildError::KernelCrateNotFound(path) => {
                write!(f, "Kernel crate not found: {}", path.display())
            }
            BuildError::CargoFailed(detail) => write!(f, "Cargo invocation failed: {}", detail),
            BuildError::CompilationFailed(stderr) => {
                write!(f, "Kernel compilation failed:\n{}", stderr)
            }
            BuildError::PtxNotFound(detail) => write!(f, "PTX file not found: {}", detail),
            BuildError::PtxReadFailed(detail) => write!(f, "Failed to read PTX: {}", detail),
            BuildError::WriteFailed(detail) => write!(f, "Failed to write output: {}", detail),
        }
    }
}

impl std::error::Error for BuildError {}
283
284// ============================================================================
285// PTX parsing
286// ============================================================================
287
/// Collect the kernel entry-point names declared in a PTX listing.
///
/// A line counts as an entry point when, after trimming surrounding
/// whitespace, it begins with `.visible .entry `. The name is whatever
/// follows that prefix up to the first `(` (or the end of the line), trimmed.
/// Lines that yield an empty name are skipped. Names are returned in the
/// order they appear.
fn parse_kernel_entries(ptx: &str) -> Vec<String> {
    const ENTRY_PREFIX: &str = ".visible .entry ";

    let mut names = Vec::new();
    for raw_line in ptx.lines() {
        let Some(after_prefix) = raw_line.trim().strip_prefix(ENTRY_PREFIX) else {
            continue;
        };
        // `split` always yields at least one piece, so the fallback is never taken.
        let candidate = after_prefix.split('(').next().unwrap_or("").trim();
        if !candidate.is_empty() {
            names.push(candidate.to_string());
        }
    }
    names
}
308
309// ============================================================================
310// Code generation
311// ============================================================================
312
/// Render the source text of the generated `kernels.rs` module.
///
/// The text consists of, in order: a comment header naming `kernel_crate`
/// and the kernel count, a `KERNEL_PTX` constant that `include_str!`s
/// `kernels.ptx` from `OUT_DIR`, a `Kernels` struct with one `CudaFunction`
/// field per entry in `kernels`, and a `Kernels::load` constructor that
/// looks each kernel up by name.
///
/// Kernel names are pasted verbatim as field and variable names.
fn generate_rust_module(kernel_crate: &Path, kernels: &[String]) -> String {
    // Per-kernel fragments; each one is repeated once per kernel name.
    let doc_bullets: String = kernels
        .iter()
        .map(|kernel| format!("/// - `{}` \n", kernel))
        .collect();
    let struct_fields: String = kernels
        .iter()
        .map(|kernel| {
            format!(
                "    /// Kernel: `{k}`\npub {k}: ::cudarc::driver::CudaFunction,\n",
                k = kernel,
            )
        })
        .collect();
    let lookups: String = kernels
        .iter()
        .map(|kernel| {
            format!(
                "        let {k} = module.load_function(\"{k}\")?;\n",
                k = kernel,
            )
        })
        .collect();
    let initializers: String = kernels
        .iter()
        .map(|kernel| format!("            {},\n", kernel))
        .collect();

    // Fixed fragments.
    let header = format!(
        "// Auto-generated by warp-types-builder. Do not edit.\n// Kernel crate: {}\n// Kernel count: {}\n\n",
        kernel_crate.display(),
        kernels.len(),
    );
    let ptx_constant = concat!(
        "/// Raw PTX assembly for all kernels in this module.\n",
        "pub const KERNEL_PTX: &str = include_str!(concat!(env!(\"OUT_DIR\"), \"/kernels.ptx\"));\n\n",
    );
    let struct_docs = concat!(
        "/// Loaded GPU kernel functions.\n",
        "///\n",
        "/// Created by [`Kernels::load`], which parses the PTX and extracts\n",
        "/// each kernel entry point as a ready-to-launch [`CudaFunction`].\n",
        "///\n",
        "/// # Available kernels\n",
    );
    let struct_open = concat!(
        "pub struct Kernels {\n",
        "_module: ::std::sync::Arc<::cudarc::driver::CudaModule>,\n",
    );
    let impl_open = concat!(
        "impl Kernels {\n",
        "/// Load all kernels from the compiled PTX.\n",
        "///\n",
        "/// Parses the embedded PTX assembly, loads it as a CUDA module,\n",
        "/// and extracts each kernel entry point by name.\n",
        "pub fn load(ctx: &::std::sync::Arc<::cudarc::driver::CudaContext>) -> ",
        "::std::result::Result<Self, Box<dyn ::std::error::Error>> {\n",
        "let ptx = ::cudarc::nvrtc::Ptx::from_src(KERNEL_PTX.to_string());\n",
        "let module = ctx.load_module(ptx)?;\n",
    );

    // Stitch everything together in output order.
    [
        header.as_str(),
        ptx_constant,
        struct_docs,
        doc_bullets.as_str(),
        struct_open,
        struct_fields.as_str(),
        "}\n\n",
        impl_open,
        lookups.as_str(),
        "        let _module = module;\n",
        "        Ok(Kernels {\n            _module,\n",
        initializers.as_str(),
        "        })\n    }\n}\n",
    ]
    .concat()
}
384
385// ============================================================================
386// File watching
387// ============================================================================
388
389/// Emit `cargo:rerun-if-changed` for all files in the kernel crate recursively.
390fn emit_rerun_if_changed(kernel_dir: &Path) {
391    println!(
392        "cargo:rerun-if-changed={}",
393        kernel_dir.join("Cargo.toml").display()
394    );
395
396    let src_dir = kernel_dir.join("src");
397    if src_dir.exists() {
398        emit_rerun_recursive(&src_dir);
399    }
400}
401
/// Walk `dir` depth-first and print a `cargo:rerun-if-changed` directive for
/// every non-directory entry found.
///
/// Best effort: a directory that cannot be listed, and individual entries
/// that fail to read, are silently skipped.
fn emit_rerun_recursive(dir: &Path) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };

    for child in listing.flatten().map(|entry| entry.path()) {
        if child.is_dir() {
            emit_rerun_recursive(&child);
            continue;
        }
        println!("cargo:rerun-if-changed={}", child.display());
    }
}
414
415// ============================================================================
416// PTX file discovery
417// ============================================================================
418
419/// Search for the PTX (.s) file in the target directory.
420fn find_ptx_file(target_dir: &Path, kernel_dir: &Path) -> Result<PathBuf, BuildError> {
421    // Get the crate name from Cargo.toml
422    let cargo_toml = kernel_dir.join("Cargo.toml");
423    let content = std::fs::read_to_string(&cargo_toml).map_err(|e| {
424        BuildError::PtxNotFound(format!("Can't read {}: {}", cargo_toml.display(), e))
425    })?;
426
427    // Simple TOML parsing — find name under [package] section
428    let crate_name = {
429        let mut in_package = false;
430        content
431            .lines()
432            .find_map(|line| {
433                let line = line.trim();
434                if line.starts_with('[') {
435                    in_package = line == "[package]";
436                    return None;
437                }
438                if in_package && line.starts_with("name") {
439                    let val = line.split('=').nth(1)?.trim().trim_matches('"');
440                    return Some(val.replace('-', "_"));
441                }
442                None
443            })
444            .unwrap_or_else(|| "kernels".to_string())
445    };
446
447    // Check common locations for the .s file
448    let candidates = [
449        target_dir.join(format!("{}.s", crate_name)),
450        target_dir.join(format!("lib{}.s", crate_name)),
451        target_dir.join(format!("{}.ptx", crate_name)),
452        target_dir.join("deps").join(format!("{}.s", crate_name)),
453        target_dir.join("deps").join(format!("lib{}.s", crate_name)),
454    ];
455
456    for path in &candidates {
457        if path.exists() {
458            return Ok(path.clone());
459        }
460    }
461
462    // Fallback: search deps/ for .s files matching the crate name pattern
463    let deps = target_dir.join("deps");
464    if let Ok(entries) = std::fs::read_dir(&deps) {
465        for entry in entries.flatten() {
466            let p = entry.path();
467            if p.extension().is_some_and(|e| e == "s") {
468                let fname = p
469                    .file_stem()
470                    .map(|s| s.to_string_lossy().to_string())
471                    .unwrap_or_default();
472                // Match crate_name-HASH.s pattern, skip core/compiler_builtins
473                if fname.starts_with(&crate_name)
474                    && !fname.starts_with("core-")
475                    && !fname.starts_with("compiler_builtins-")
476                {
477                    return Ok(p);
478                }
479            }
480        }
481    }
482
483    // Last resort: any non-core .s file in deps/ (sorted for determinism)
484    if let Ok(entries) = std::fs::read_dir(&deps) {
485        let mut candidates: Vec<PathBuf> = entries
486            .flatten()
487            .map(|e| e.path())
488            .filter(|p| {
489                p.extension().is_some_and(|e| e == "s") && {
490                    let fname = p
491                        .file_stem()
492                        .map(|s| s.to_string_lossy().to_string())
493                        .unwrap_or_default();
494                    !fname.starts_with("core-")
495                        && !fname.starts_with("compiler_builtins-")
496                        && !fname.starts_with("warp_types-")
497                }
498            })
499            .collect();
500        candidates.sort();
501        if let Some(p) = candidates.into_iter().next() {
502            return Ok(p);
503        }
504    }
505
506    Err(BuildError::PtxNotFound(format!(
507        "No .s/.ptx file found in {}. Crate name: '{}'. Checked: {:?}",
508        target_dir.display(),
509        crate_name,
510        candidates
511            .iter()
512            .map(|c| c.display().to_string())
513            .collect::<Vec<_>>()
514    )))
515}