1use std::env;
31use std::path::{Path, PathBuf};
32use std::process::Command;
33
/// GPU backend the kernel crate is compiled for.
///
/// Each variant maps to a rustc target triple (see `GpuTarget::triple`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GpuTarget {
    /// NVIDIA GPUs, via the `nvptx64-nvidia-cuda` target (emits PTX).
    Nvidia,
    /// AMD GPUs, via the `amdgcn-amd-amdhsa` target.
    Amd,
}
42
43impl GpuTarget {
44 fn triple(&self) -> &str {
45 match self {
46 GpuTarget::Nvidia => "nvptx64-nvidia-cuda",
47 GpuTarget::Amd => "amdgcn-amd-amdhsa",
48 }
49 }
50
51 #[allow(dead_code)]
52 fn asm_extension(&self) -> &str {
53 match self {
54 GpuTarget::Nvidia => "s", GpuTarget::Amd => "s", }
57 }
58}
59
60pub struct WarpBuilder {
62 kernel_crate: PathBuf,
64 toolchain: String,
66 release: bool,
68 features: Vec<String>,
70 target: GpuTarget,
72 sm_arch: String,
79}
80
81impl WarpBuilder {
82 pub fn new(kernel_crate_path: impl Into<PathBuf>) -> Self {
87 WarpBuilder {
88 kernel_crate: kernel_crate_path.into(),
89 toolchain: "nightly".to_string(),
90 release: true,
91 features: Vec::new(),
92 target: GpuTarget::Nvidia,
93 sm_arch: "sm_70".to_string(),
94 }
95 }
96
97 pub fn toolchain(mut self, toolchain: impl Into<String>) -> Self {
99 self.toolchain = toolchain.into();
100 self
101 }
102
103 pub fn debug(mut self) -> Self {
105 self.release = false;
106 self
107 }
108
109 pub fn target(mut self, target: GpuTarget) -> Self {
111 self.target = target;
112 self
113 }
114
115 pub fn sm_arch(mut self, arch: impl Into<String>) -> Self {
120 self.sm_arch = arch.into();
121 self
122 }
123
124 pub fn feature(mut self, feature: impl Into<String>) -> Self {
126 self.features.push(feature.into());
127 self
128 }
129
130 pub fn build(self) -> Result<BuildResult, BuildError> {
139 let manifest_dir =
140 env::var("CARGO_MANIFEST_DIR").map_err(|_| BuildError::NotInBuildScript)?;
141 let out_dir = env::var("OUT_DIR").map_err(|_| BuildError::NotInBuildScript)?;
142
143 let kernel_dir = Path::new(&manifest_dir).join(&self.kernel_crate);
144 if !kernel_dir.exists() {
145 return Err(BuildError::KernelCrateNotFound(kernel_dir));
146 }
147
148 emit_rerun_if_changed(&kernel_dir);
150
151 let mut cmd = Command::new("cargo");
156 cmd.arg("rustc")
157 .arg("--target")
158 .arg(self.target.triple())
159 .arg("-Z")
160 .arg("build-std=core")
161 .current_dir(&kernel_dir);
162
163 if self.release {
164 cmd.arg("--release");
165 }
166
167 for feat in &self.features {
169 cmd.arg("--features").arg(feat);
170 }
171
172 cmd.arg("--")
177 .arg("--emit=asm")
178 .arg("-C")
179 .arg(format!("target-cpu={}", self.sm_arch));
180
181 cmd.env("RUSTUP_TOOLCHAIN", &self.toolchain);
187 cmd.env_remove("RUSTC");
188 cmd.env("RUSTFLAGS", "--cfg warp_kernel_build");
189
190 let output = cmd
191 .output()
192 .map_err(|e| BuildError::CargoFailed(format!("Failed to run cargo: {}", e)))?;
193
194 if !output.status.success() {
195 let stderr = String::from_utf8_lossy(&output.stderr);
196 return Err(BuildError::CompilationFailed(stderr.into_owned()));
197 }
198
199 let profile = if self.release { "release" } else { "debug" };
201 let target_dir = kernel_dir
202 .join("target")
203 .join(self.target.triple())
204 .join(profile);
205
206 let ptx_path = find_ptx_file(&target_dir, &kernel_dir)?;
207
208 let ptx_content = std::fs::read_to_string(&ptx_path)
210 .map_err(|e| BuildError::PtxReadFailed(format!("{}: {}", ptx_path.display(), e)))?;
211
212 let kernels = parse_kernel_entries(&ptx_content);
214
215 let out_ptx = Path::new(&out_dir).join("kernels.ptx");
217 std::fs::write(&out_ptx, &ptx_content)
218 .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_ptx.display(), e)))?;
219
220 let out_rs = Path::new(&out_dir).join("kernels.rs");
222 let rs_content = generate_rust_module(&self.kernel_crate, &kernels);
223 std::fs::write(&out_rs, &rs_content)
224 .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_rs.display(), e)))?;
225
226 Ok(BuildResult {
227 ptx_path: out_ptx,
228 module_path: out_rs,
229 kernel_names: kernels,
230 })
231 }
232}
233
/// Artifacts produced by a successful [`WarpBuilder::build`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BuildResult {
    /// `OUT_DIR/kernels.ptx`: a copy of the compiled kernel assembly.
    pub ptx_path: PathBuf,
    /// `OUT_DIR/kernels.rs`: the generated Rust loader module.
    pub module_path: PathBuf,
    /// Kernel entry-point names found in the assembly, in file order.
    pub kernel_names: Vec<String>,
}
243
/// Errors returned by [`WarpBuilder::build`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuildError {
    /// `CARGO_MANIFEST_DIR` or `OUT_DIR` is unset, i.e. not running inside a build script.
    NotInBuildScript,
    /// The kernel crate directory (joined onto `CARGO_MANIFEST_DIR`) does not exist.
    KernelCrateNotFound(PathBuf),
    /// The `cargo` process could not be spawned.
    CargoFailed(String),
    /// The nested kernel build exited unsuccessfully; carries its stderr.
    CompilationFailed(String),
    /// The kernel `Cargo.toml` could not be read, or no assembly output was found.
    PtxNotFound(String),
    /// The located assembly file could not be read.
    PtxReadFailed(String),
    /// Writing `kernels.ptx` or `kernels.rs` into `OUT_DIR` failed.
    WriteFailed(String),
}
262
263impl std::fmt::Display for BuildError {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 match self {
266 BuildError::NotInBuildScript => write!(
267 f,
268 "Not running in a build script (CARGO_MANIFEST_DIR not set)"
269 ),
270 BuildError::KernelCrateNotFound(p) => {
271 write!(f, "Kernel crate not found: {}", p.display())
272 }
273 BuildError::CargoFailed(s) => write!(f, "Cargo invocation failed: {}", s),
274 BuildError::CompilationFailed(s) => write!(f, "Kernel compilation failed:\n{}", s),
275 BuildError::PtxNotFound(s) => write!(f, "PTX file not found: {}", s),
276 BuildError::PtxReadFailed(s) => write!(f, "Failed to read PTX: {}", s),
277 BuildError::WriteFailed(s) => write!(f, "Failed to write output: {}", s),
278 }
279 }
280}
281
282impl std::error::Error for BuildError {}
283
/// Extracts the names of externally visible kernel entry points from PTX text.
///
/// A kernel is a line of the form `.visible .entry <name>(…`. Any run of whitespace (spaces or
/// tabs) may separate the directives. The name is the text between `.entry` and the opening
/// parenthesis, or the rest of the line when the parameter list starts on a following line.
/// Names are returned in file order; an empty input yields an empty vector.
fn parse_kernel_entries(ptx: &str) -> Vec<String> {
    /// If `s` starts with `directive` followed by whitespace, returns the remainder with
    /// leading whitespace removed; otherwise `None` (so `.visible.entry` does not match).
    fn strip_directive<'a>(s: &'a str, directive: &str) -> Option<&'a str> {
        let rest = s.strip_prefix(directive)?;
        rest.starts_with(char::is_whitespace)
            .then_some(rest.trim_start())
    }

    ptx.lines()
        .filter_map(|line| {
            let rest = strip_directive(line.trim_start(), ".visible")?;
            let rest = strip_directive(rest, ".entry")?;
            let name = rest.split('(').next()?.trim();
            (!name.is_empty()).then(|| name.to_string())
        })
        .collect()
}
308
/// Renders the Rust source written to `OUT_DIR/kernels.rs`.
///
/// The generated code defines:
/// * `KERNEL_PTX`: the PTX text, embedded with `include_str!` from `OUT_DIR/kernels.ptx`;
/// * `Kernels`: one `pub ::cudarc::driver::CudaFunction` field per entry in `kernels`, plus a
///   private `_module` field holding the loaded `CudaModule`;
/// * `Kernels::load`: loads `KERNEL_PTX` into a CUDA context and resolves each kernel by name.
///
/// `kernel_crate` is only echoed into the header comment. Kernel names that are Rust keywords
/// are emitted as raw identifiers (e.g. `r#type`) so that the generated code still compiles.
fn generate_rust_module(kernel_crate: &Path, kernels: &[String]) -> String {
    /// Spells `name` as a Rust identifier, using raw-identifier syntax for keywords.
    fn rust_ident(name: &str) -> String {
        const KEYWORDS: &[&str] = &[
            "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
            "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
            "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
            "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
            "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
        ];
        if KEYWORDS.contains(&name) {
            format!("r#{}", name)
        } else {
            name.to_string()
        }
    }

    let mut code = format!(
        "// Auto-generated by warp-types-builder. Do not edit.\n\
         // Kernel crate: {}\n\
         // Kernel count: {}\n\n",
        kernel_crate.display(),
        kernels.len(),
    );

    // The `CudaFunction` link uses an explicit path: the generated module never imports the
    // name, so a bare intra-doc link would not resolve.
    code.push_str(
        r#"/// Raw PTX assembly for all kernels in this module.
pub const KERNEL_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/kernels.ptx"));

/// Loaded GPU kernel functions.
///
/// Created by [`Kernels::load`], which parses the PTX and extracts
/// each kernel entry point as a ready-to-launch
/// [`CudaFunction`](cudarc::driver::CudaFunction).
///
/// # Available kernels
"#,
    );
    for name in kernels {
        code.push_str(&format!("/// - `{}`\n", name));
    }

    code.push_str(
        "pub struct Kernels {
    _module: ::std::sync::Arc<::cudarc::driver::CudaModule>,
",
    );
    for name in kernels {
        code.push_str(&format!(
            "    /// Kernel: `{}`\n    pub {}: ::cudarc::driver::CudaFunction,\n",
            name,
            rust_ident(name),
        ));
    }

    code.push_str(
        "}

impl Kernels {
    /// Load all kernels from the compiled PTX.
    ///
    /// Parses the embedded PTX assembly, loads it as a CUDA module,
    /// and extracts each kernel entry point by name.
    pub fn load(
        ctx: &::std::sync::Arc<::cudarc::driver::CudaContext>,
    ) -> ::std::result::Result<Self, Box<dyn ::std::error::Error>> {
        let ptx = ::cudarc::nvrtc::Ptx::from_src(KERNEL_PTX.to_string());
        let module = ctx.load_module(ptx)?;
        Ok(Kernels {
",
    );
    // Initialise the fields straight from `module` instead of through `let` bindings. With
    // bindings, a kernel named `module` would shadow the loader's module and break the code.
    // `{:?}` renders the name as a properly escaped string literal.
    for name in kernels {
        code.push_str(&format!(
            "            {}: module.load_function({:?})?,\n",
            rust_ident(name),
            name,
        ));
    }
    // `module` is moved into the struct last, after every `load_function` borrow has ended.
    code.push_str("            _module: module,\n        })\n    }\n}\n");

    code
}
384
385fn emit_rerun_if_changed(kernel_dir: &Path) {
391 println!(
392 "cargo:rerun-if-changed={}",
393 kernel_dir.join("Cargo.toml").display()
394 );
395
396 let src_dir = kernel_dir.join("src");
397 if src_dir.exists() {
398 emit_rerun_recursive(&src_dir);
399 }
400}
401
/// Walks `dir` depth-first, printing a `cargo:rerun-if-changed` line for every entry that is not
/// a directory.
///
/// Best-effort: a directory that cannot be read, or an entry that errors, is skipped silently.
fn emit_rerun_recursive(dir: &Path) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };
    for child in listing.flatten().map(|entry| entry.path()) {
        if !child.is_dir() {
            println!("cargo:rerun-if-changed={}", child.display());
            continue;
        }
        emit_rerun_recursive(&child);
    }
}
414
415fn find_ptx_file(target_dir: &Path, kernel_dir: &Path) -> Result<PathBuf, BuildError> {
421 let cargo_toml = kernel_dir.join("Cargo.toml");
423 let content = std::fs::read_to_string(&cargo_toml).map_err(|e| {
424 BuildError::PtxNotFound(format!("Can't read {}: {}", cargo_toml.display(), e))
425 })?;
426
427 let crate_name = {
429 let mut in_package = false;
430 content
431 .lines()
432 .find_map(|line| {
433 let line = line.trim();
434 if line.starts_with('[') {
435 in_package = line == "[package]";
436 return None;
437 }
438 if in_package && line.starts_with("name") {
439 let val = line.split('=').nth(1)?.trim().trim_matches('"');
440 return Some(val.replace('-', "_"));
441 }
442 None
443 })
444 .unwrap_or_else(|| "kernels".to_string())
445 };
446
447 let candidates = [
449 target_dir.join(format!("{}.s", crate_name)),
450 target_dir.join(format!("lib{}.s", crate_name)),
451 target_dir.join(format!("{}.ptx", crate_name)),
452 target_dir.join("deps").join(format!("{}.s", crate_name)),
453 target_dir.join("deps").join(format!("lib{}.s", crate_name)),
454 ];
455
456 for path in &candidates {
457 if path.exists() {
458 return Ok(path.clone());
459 }
460 }
461
462 let deps = target_dir.join("deps");
464 if let Ok(entries) = std::fs::read_dir(&deps) {
465 for entry in entries.flatten() {
466 let p = entry.path();
467 if p.extension().is_some_and(|e| e == "s") {
468 let fname = p
469 .file_stem()
470 .map(|s| s.to_string_lossy().to_string())
471 .unwrap_or_default();
472 if fname.starts_with(&crate_name)
474 && !fname.starts_with("core-")
475 && !fname.starts_with("compiler_builtins-")
476 {
477 return Ok(p);
478 }
479 }
480 }
481 }
482
483 if let Ok(entries) = std::fs::read_dir(&deps) {
485 let mut candidates: Vec<PathBuf> = entries
486 .flatten()
487 .map(|e| e.path())
488 .filter(|p| {
489 p.extension().is_some_and(|e| e == "s") && {
490 let fname = p
491 .file_stem()
492 .map(|s| s.to_string_lossy().to_string())
493 .unwrap_or_default();
494 !fname.starts_with("core-")
495 && !fname.starts_with("compiler_builtins-")
496 && !fname.starts_with("warp_types-")
497 }
498 })
499 .collect();
500 candidates.sort();
501 if let Some(p) = candidates.into_iter().next() {
502 return Ok(p);
503 }
504 }
505
506 Err(BuildError::PtxNotFound(format!(
507 "No .s/.ptx file found in {}. Crate name: '{}'. Checked: {:?}",
508 target_dir.display(),
509 crate_name,
510 candidates
511 .iter()
512 .map(|c| c.display().to_string())
513 .collect::<Vec<_>>()
514 )))
515}