1use std::env;
31use std::path::{Path, PathBuf};
32use std::process::Command;
33
/// GPU architecture family that the kernel crate is compiled for.
///
/// NOTE(review): the builder always stores the assembly as `kernels.ptx` and
/// generates a CUDA (cudarc) loader, so `Amd` output is probably not usable
/// end-to-end yet — confirm before relying on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GpuTarget {
    /// NVIDIA GPUs, compiled with the `nvptx64-nvidia-cuda` rustc target.
    Nvidia,
    /// AMD GPUs, compiled with the `amdgcn-amd-amdhsa` rustc target.
    Amd,
}

impl GpuTarget {
    /// Returns the rustc target triple passed to `cargo rustc --target`.
    ///
    /// The result is a string literal, so it is `'static` and not tied to `self`.
    fn triple(&self) -> &'static str {
        match self {
            GpuTarget::Nvidia => "nvptx64-nvidia-cuda",
            GpuTarget::Amd => "amdgcn-amd-amdhsa",
        }
    }

    /// Returns the file extension of the `--emit=asm` output for this target.
    // Kept (and allowed as dead code) for locating assembly output; it has no
    // caller in the visible code. The match is exhaustive on purpose so that
    // adding a target forces a decision here.
    #[allow(dead_code)]
    fn asm_extension(&self) -> &'static str {
        match self {
            GpuTarget::Nvidia | GpuTarget::Amd => "s",
        }
    }
}
59
60pub struct WarpBuilder {
62 kernel_crate: PathBuf,
64 toolchain: String,
66 release: bool,
68 features: Vec<String>,
70 target: GpuTarget,
72}
73
74impl WarpBuilder {
75 pub fn new(kernel_crate_path: impl Into<PathBuf>) -> Self {
80 WarpBuilder {
81 kernel_crate: kernel_crate_path.into(),
82 toolchain: "nightly".to_string(),
83 release: true,
84 features: Vec::new(),
85 target: GpuTarget::Nvidia,
86 }
87 }
88
89 pub fn toolchain(mut self, toolchain: impl Into<String>) -> Self {
91 self.toolchain = toolchain.into();
92 self
93 }
94
95 pub fn debug(mut self) -> Self {
97 self.release = false;
98 self
99 }
100
101 pub fn target(mut self, target: GpuTarget) -> Self {
103 self.target = target;
104 self
105 }
106
107 pub fn feature(mut self, feature: impl Into<String>) -> Self {
109 self.features.push(feature.into());
110 self
111 }
112
113 pub fn build(self) -> Result<BuildResult, BuildError> {
122 let manifest_dir =
123 env::var("CARGO_MANIFEST_DIR").map_err(|_| BuildError::NotInBuildScript)?;
124 let out_dir = env::var("OUT_DIR").map_err(|_| BuildError::NotInBuildScript)?;
125
126 let kernel_dir = Path::new(&manifest_dir).join(&self.kernel_crate);
127 if !kernel_dir.exists() {
128 return Err(BuildError::KernelCrateNotFound(kernel_dir));
129 }
130
131 emit_rerun_if_changed(&kernel_dir);
133
134 let mut cmd = Command::new("cargo");
139 cmd.arg("rustc")
140 .arg("--target")
141 .arg(self.target.triple())
142 .arg("-Z")
143 .arg("build-std=core")
144 .current_dir(&kernel_dir);
145
146 if self.release {
147 cmd.arg("--release");
148 }
149
150 for feat in &self.features {
152 cmd.arg("--features").arg(feat);
153 }
154
155 cmd.arg("--").arg("--emit=asm");
157
158 cmd.env("RUSTUP_TOOLCHAIN", &self.toolchain);
164 cmd.env_remove("RUSTC");
165 cmd.env("RUSTFLAGS", "--cfg warp_kernel_build");
166
167 let output = cmd
168 .output()
169 .map_err(|e| BuildError::CargoFailed(format!("Failed to run cargo: {}", e)))?;
170
171 if !output.status.success() {
172 let stderr = String::from_utf8_lossy(&output.stderr);
173 return Err(BuildError::CompilationFailed(stderr.into_owned()));
174 }
175
176 let profile = if self.release { "release" } else { "debug" };
178 let target_dir = kernel_dir
179 .join("target")
180 .join(self.target.triple())
181 .join(profile);
182
183 let ptx_path = find_ptx_file(&target_dir, &kernel_dir)?;
184
185 let ptx_content = std::fs::read_to_string(&ptx_path)
187 .map_err(|e| BuildError::PtxReadFailed(format!("{}: {}", ptx_path.display(), e)))?;
188
189 let kernels = parse_kernel_entries(&ptx_content);
191
192 let out_ptx = Path::new(&out_dir).join("kernels.ptx");
194 std::fs::write(&out_ptx, &ptx_content)
195 .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_ptx.display(), e)))?;
196
197 let out_rs = Path::new(&out_dir).join("kernels.rs");
199 let rs_content = generate_rust_module(&self.kernel_crate, &kernels);
200 std::fs::write(&out_rs, &rs_content)
201 .map_err(|e| BuildError::WriteFailed(format!("{}: {}", out_rs.display(), e)))?;
202
203 Ok(BuildResult {
204 ptx_path: out_ptx,
205 module_path: out_rs,
206 kernel_names: kernels,
207 })
208 }
209}
210
/// Artifacts produced by a successful `WarpBuilder::build`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BuildResult {
    /// Path of the assembly copied into `OUT_DIR` (`kernels.ptx`).
    pub ptx_path: PathBuf,
    /// Path of the generated Rust module in `OUT_DIR` (`kernels.rs`).
    pub module_path: PathBuf,
    /// Kernel entry-point names found in the assembly, in file order.
    pub kernel_names: Vec<String>,
}
220
/// Errors returned by `WarpBuilder::build`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuildError {
    /// `CARGO_MANIFEST_DIR` or `OUT_DIR` is missing from the environment,
    /// i.e. the builder was not invoked from a Cargo build script.
    NotInBuildScript,
    /// The kernel crate directory does not exist.
    KernelCrateNotFound(PathBuf),
    /// The nested `cargo` process could not be run.
    CargoFailed(String),
    /// The nested `cargo` process exited unsuccessfully; holds its stderr.
    CompilationFailed(String),
    /// The compiled assembly (or the kernel crate's `Cargo.toml`) could not be found/read.
    PtxNotFound(String),
    /// The located assembly file could not be read.
    PtxReadFailed(String),
    /// Writing a generated file failed.
    WriteFailed(String),
}

impl std::fmt::Display for BuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `build` returns this variant when either env var is missing, so name both.
            BuildError::NotInBuildScript => write!(
                f,
                "Not running in a build script (CARGO_MANIFEST_DIR or OUT_DIR not set)"
            ),
            BuildError::KernelCrateNotFound(p) => {
                write!(f, "Kernel crate not found: {}", p.display())
            }
            BuildError::CargoFailed(s) => write!(f, "Cargo invocation failed: {}", s),
            BuildError::CompilationFailed(s) => write!(f, "Kernel compilation failed:\n{}", s),
            BuildError::PtxNotFound(s) => write!(f, "PTX file not found: {}", s),
            BuildError::PtxReadFailed(s) => write!(f, "Failed to read PTX: {}", s),
            BuildError::WriteFailed(s) => write!(f, "Failed to write output: {}", s),
        }
    }
}

impl std::error::Error for BuildError {}
260
/// Extracts kernel entry-point names from PTX text.
///
/// A kernel is a line of the form `.visible .entry NAME(`:
/// - leading and trailing whitespace on the line is ignored;
/// - the two directives and the name may be separated by any whitespace
///   (spaces or tabs);
/// - the `(` may be missing from the line, e.g. when the parameter list starts
///   on the next line.
///
/// `.func` declarations and `.entry` without `.visible` are ignored. Names are
/// returned in the order they appear.
fn parse_kernel_entries(ptx: &str) -> Vec<String> {
    ptx.lines()
        .filter_map(|line| {
            let rest = line.trim().strip_prefix(".visible")?;
            // Require a separator so `.visible.entry` / `.visibleX` are not accepted.
            let rest = rest.strip_prefix(char::is_whitespace)?.trim_start();
            let rest = rest.strip_prefix(".entry")?;
            let rest = rest.strip_prefix(char::is_whitespace)?;
            let name = rest.split('(').next()?.trim();
            (!name.is_empty()).then(|| name.to_string())
        })
        .collect()
}
285
/// Generates the Rust source that `build` writes to `OUT_DIR/kernels.rs`.
///
/// The generated module contains:
/// - `KERNEL_PTX`: the assembly text, embedded with `include_str!` from
///   `$OUT_DIR/kernels.ptx`;
/// - `Kernels`: one `pub` `cudarc::driver::CudaFunction` field per kernel, plus a
///   private `_module` handle that keeps the loaded CUDA module alive;
/// - `Kernels::load`: loads the PTX into a context and resolves every kernel by name.
///
/// Each entry of `kernels` is used verbatim as a Rust field name and as the symbol
/// passed to `load_function`. It must therefore be a valid, non-keyword Rust
/// identifier other than `_module`.
fn generate_rust_module(kernel_crate: &Path, kernels: &[String]) -> String {
    let mut code = String::new();

    // File header.
    code.push_str("// Auto-generated by warp-types-builder. Do not edit.\n");
    code.push_str(&format!("// Kernel crate: {}\n", kernel_crate.display()));
    code.push_str(&format!("// Kernel count: {}\n\n", kernels.len()));

    // Embedded assembly text.
    code.push_str("/// Raw PTX assembly for all kernels in this module.\n");
    code.push_str(
        "pub const KERNEL_PTX: &str = include_str!(concat!(env!(\"OUT_DIR\"), \"/kernels.ptx\"));\n\n",
    );

    // `Kernels` struct. The doc link uses a full path because the generated
    // module does not import `CudaFunction`.
    code.push_str("/// Loaded GPU kernel functions.\n");
    code.push_str("///\n");
    code.push_str("/// Created by [`Kernels::load`], which parses the PTX and extracts\n");
    code.push_str(
        "/// each kernel entry point as a ready-to-launch \
         [`CudaFunction`](cudarc::driver::CudaFunction).\n",
    );
    code.push_str("///\n");
    code.push_str("/// # Available kernels\n");
    for name in kernels {
        code.push_str(&format!("/// - `{name}`\n"));
    }
    code.push_str("pub struct Kernels {\n");
    code.push_str("    _module: ::std::sync::Arc<::cudarc::driver::CudaModule>,\n");
    for name in kernels {
        code.push_str(&format!("    /// Kernel: `{name}`\n"));
        code.push_str(&format!("    pub {name}: ::cudarc::driver::CudaFunction,\n"));
    }
    code.push_str("}\n\n");

    // `Kernels::load`.
    code.push_str("impl Kernels {\n");
    code.push_str("    /// Load all kernels from the compiled PTX.\n");
    code.push_str("    ///\n");
    code.push_str("    /// Parses the embedded PTX assembly, loads it as a CUDA module,\n");
    code.push_str("    /// and extracts each kernel entry point by name.\n");
    code.push_str("    pub fn load(\n");
    code.push_str("        ctx: &::std::sync::Arc<::cudarc::driver::CudaContext>,\n");
    code.push_str("    ) -> ::std::result::Result<Self, Box<dyn ::std::error::Error>> {\n");
    code.push_str("        let ptx = ::cudarc::nvrtc::Ptx::from_src(KERNEL_PTX.to_string());\n");
    code.push_str("        let module = ctx.load_module(ptx)?;\n");
    // Resolve kernels inside the struct literal instead of binding each one to a
    // local: a kernel named `module` would otherwise shadow the loaded module and
    // break every later `load_function` call. Struct fields are evaluated in
    // source order, so all `load_function` borrows end before `_module: module`
    // moves the handle.
    code.push_str("        Ok(Kernels {\n");
    for name in kernels {
        code.push_str(&format!(
            "            {name}: module.load_function(\"{name}\")?,\n"
        ));
    }
    code.push_str("            _module: module,\n");
    code.push_str("        })\n");
    code.push_str("    }\n");
    code.push_str("}\n");

    code
}
361
362fn emit_rerun_if_changed(kernel_dir: &Path) {
368 println!(
369 "cargo:rerun-if-changed={}",
370 kernel_dir.join("Cargo.toml").display()
371 );
372
373 let src_dir = kernel_dir.join("src");
374 if src_dir.exists() {
375 emit_rerun_recursive(&src_dir);
376 }
377}
378
/// Prints a `cargo:rerun-if-changed` directive for every non-directory entry
/// under `dir`, descending into subdirectories.
///
/// This is best effort: an unreadable directory or entry is skipped silently.
fn emit_rerun_recursive(dir: &Path) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };
    for child in listing.filter_map(Result::ok).map(|e| e.path()) {
        if !child.is_dir() {
            println!("cargo:rerun-if-changed={}", child.display());
        } else {
            emit_rerun_recursive(&child);
        }
    }
}
391
392fn find_ptx_file(target_dir: &Path, kernel_dir: &Path) -> Result<PathBuf, BuildError> {
398 let cargo_toml = kernel_dir.join("Cargo.toml");
400 let content = std::fs::read_to_string(&cargo_toml).map_err(|e| {
401 BuildError::PtxNotFound(format!("Can't read {}: {}", cargo_toml.display(), e))
402 })?;
403
404 let crate_name = content
406 .lines()
407 .find_map(|line| {
408 let line = line.trim();
409 if line.starts_with("name") {
410 let val = line.split('=').nth(1)?.trim().trim_matches('"');
411 Some(val.replace('-', "_"))
412 } else {
413 None
414 }
415 })
416 .unwrap_or_else(|| "kernels".to_string());
417
418 let candidates = [
420 target_dir.join(format!("{}.s", crate_name)),
421 target_dir.join(format!("lib{}.s", crate_name)),
422 target_dir.join(format!("{}.ptx", crate_name)),
423 target_dir.join("deps").join(format!("{}.s", crate_name)),
424 target_dir.join("deps").join(format!("lib{}.s", crate_name)),
425 ];
426
427 for path in &candidates {
428 if path.exists() {
429 return Ok(path.clone());
430 }
431 }
432
433 let deps = target_dir.join("deps");
435 if let Ok(entries) = std::fs::read_dir(&deps) {
436 for entry in entries.flatten() {
437 let p = entry.path();
438 if p.extension().is_some_and(|e| e == "s") {
439 let fname = p
440 .file_stem()
441 .map(|s| s.to_string_lossy().to_string())
442 .unwrap_or_default();
443 if fname.starts_with(&crate_name)
445 && !fname.starts_with("core-")
446 && !fname.starts_with("compiler_builtins-")
447 {
448 return Ok(p);
449 }
450 }
451 }
452 }
453
454 if let Ok(entries) = std::fs::read_dir(&deps) {
456 for entry in entries.flatten() {
457 let p = entry.path();
458 if p.extension().is_some_and(|e| e == "s") {
459 let fname = p
460 .file_stem()
461 .map(|s| s.to_string_lossy().to_string())
462 .unwrap_or_default();
463 if !fname.starts_with("core-")
464 && !fname.starts_with("compiler_builtins-")
465 && !fname.starts_with("warp_types-")
466 {
467 return Ok(p);
468 }
469 }
470 }
471 }
472
473 Err(BuildError::PtxNotFound(format!(
474 "No .s/.ptx file found in {}. Crate name: '{}'. Checked: {:?}",
475 target_dir.display(),
476 crate_name,
477 candidates
478 .iter()
479 .map(|c| c.display().to_string())
480 .collect::<Vec<_>>()
481 )))
482}