1use xpile_backend::{
20 Artifact, Backend, BackendConfig, BackendError, EmittedText, HwProfile, MultiEmitterBackend,
21 QuorumPolicy, Target, TargetEmitter,
22};
23use xpile_contracts::ContractId;
24use xpile_meta_hir::Module;
25
26pub struct PtxBackend {
30 inner: MultiEmitterBackend,
31}
32
33impl Default for PtxBackend {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl PtxBackend {
40 pub fn new() -> Self {
41 Self {
42 inner: MultiEmitterBackend::new_single(Target::Ptx, Box::new(ScaffoldPtxEmitter)),
43 }
44 }
45
46 pub fn new_with_matmul_specialist() -> Self {
66 Self {
67 inner: MultiEmitterBackend::new_with_specialist(
68 Target::Ptx,
69 Box::new(ScaffoldPtxEmitter),
70 Box::new(MatmulSpecialistEmitter),
71 QuorumPolicy::PreferSpecialist,
72 ),
73 }
74 }
75}
76
77impl Backend for PtxBackend {
78 fn name(&self) -> &'static str {
79 "ptx"
80 }
81
82 fn targets(&self) -> &[Target] {
83 &[Target::Ptx]
84 }
85
86 fn lower(&self, module: &Module, config: &BackendConfig) -> Result<Artifact, BackendError> {
87 match &config.hardware {
91 Some(HwProfile::Ptx { .. }) => {}
92 _ => return Err(BackendError::MissingHardware(Target::Ptx)),
93 }
94 self.inner.lower(module, config)
95 }
96}
97
98struct ScaffoldPtxEmitter;
102
103impl TargetEmitter for ScaffoldPtxEmitter {
104 fn name(&self) -> &str {
105 "xpile-ptx-codegen-scaffold"
106 }
107
108 fn try_emit(
109 &self,
110 module: &Module,
111 config: &BackendConfig,
112 ) -> Option<Result<EmittedText, BackendError>> {
113 let compute_capability = match &config.hardware {
114 Some(HwProfile::Ptx { compute_capability }) => compute_capability,
115 _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
116 };
117 Some(Ok(EmittedText {
118 primary: format!(
119 "// xpile-ptx-codegen scaffold\n// module: {}\n// compute_capability: {}\n// TODO: lower to real PTX via rustc_codegen_nvvm.\n",
120 module.name, compute_capability,
121 ),
122 citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
123 }))
124 }
125}
126
127struct MatmulSpecialistEmitter;
143
144impl TargetEmitter for MatmulSpecialistEmitter {
145 fn name(&self) -> &str {
146 "matmul-specialist-mock"
147 }
148
149 fn try_emit(
150 &self,
151 module: &Module,
152 config: &BackendConfig,
153 ) -> Option<Result<EmittedText, BackendError>> {
154 if !module.name.starts_with("matmul_") {
155 return None;
156 }
157 let compute_capability = match &config.hardware {
158 Some(HwProfile::Ptx { compute_capability }) => compute_capability,
159 _ => return Some(Err(BackendError::MissingHardware(Target::Ptx))),
160 };
161 Some(Ok(EmittedText {
162 primary: format!(
163 "// matmul-specialist scaffold\n// module: {}\n// compute_capability: {}\n// TODO: emit mma.sync.aligned via aprender-gpu shape templates.\n",
164 module.name, compute_capability,
165 ),
166 citations: vec![ContractId::new("C-COMPILE-RUST-TO-PTX-MMA")],
167 }))
168 }
169}
170
/// Reasons a PTX text can fail structural validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PtxValidationError {
    /// No `.version` directive was found.
    MissingVersion,
    /// No `.target` directive was found.
    MissingTarget,
    /// A `.target` directive exists but names a different architecture than
    /// the one requested.
    TargetMismatch {
        /// Compute capability the caller asked for.
        expected: String,
        /// Architecture actually named by the `.target` directive.
        found: String,
    },
    /// No `.address_size 64` directive was found.
    MissingAddressSize,
    /// No `.visible .entry` kernel was found.
    MissingEntry,
}

impl std::fmt::Display for PtxValidationError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the mismatch variant needs interpolation; everything else is a
        // fixed message written verbatim.
        let fixed_message = match self {
            Self::TargetMismatch { expected, found } => {
                return write!(
                    f,
                    "PTX `.target {found}` does not match requested compute capability `{expected}`"
                );
            }
            Self::MissingVersion => "PTX is missing a `.version` directive",
            Self::MissingTarget => "PTX is missing a `.target` directive",
            Self::MissingAddressSize => "PTX is missing `.address_size 64`",
            Self::MissingEntry => "PTX has no `.visible .entry` kernel",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for PtxValidationError {}
213
214pub fn ptx_looks_real(text: &str) -> bool {
217 directive_present(text, ".version")
218}
219
/// Builds the `-arch=<capability>` flag for the given compute capability.
///
/// The capability string is appended verbatim; it is not validated here.
pub fn ptxas_arch(compute_capability: &str) -> String {
    let mut flag = String::from("-arch=");
    flag.push_str(compute_capability);
    flag
}
227
228pub fn validate_ptx(text: &str, compute_capability: &str) -> Result<(), PtxValidationError> {
234 if !directive_present(text, ".version") {
235 return Err(PtxValidationError::MissingVersion);
236 }
237 let target = ptx_target_arch(text).ok_or(PtxValidationError::MissingTarget)?;
238 if target != compute_capability {
239 return Err(PtxValidationError::TargetMismatch {
240 expected: compute_capability.to_string(),
241 found: target,
242 });
243 }
244 if !directive_present(text, ".address_size 64") {
245 return Err(PtxValidationError::MissingAddressSize);
246 }
247 if !text.contains(".visible .entry") {
248 return Err(PtxValidationError::MissingEntry);
249 }
250 Ok(())
251}
252
/// Reports whether any line of `text` begins with `directive`.
///
/// Matching is done on whitespace-separated tokens rather than raw string
/// prefixes:
///
/// * `.version` does not match a line starting with `.versionfoo`, and
///   `.address_size 64` does not match `.address_size 640`;
/// * any run of spaces or tabs may separate the tokens, so
///   `.address_size<TAB>64` matches `.address_size 64`;
/// * everything from `//` to the end of a line is treated as a comment and
///   ignored, so comment-only lines never match.
///
/// `directive` may contain several tokens (for example `.address_size 64`);
/// all of them must appear, in order, as the leading tokens of the line.
fn directive_present(text: &str, directive: &str) -> bool {
    let wanted: Vec<&str> = directive.split_whitespace().collect();
    text.lines().any(|line| {
        // Drop a trailing (or whole-line) `//` comment before tokenising.
        let code = line.split("//").next().unwrap_or("");
        let mut tokens = code.split_whitespace();
        wanted.iter().all(|w| tokens.next() == Some(*w))
    })
}
260
/// Extracts the architecture named by the first usable `.target` directive.
///
/// For each line, anything from `//` onward is discarded as a comment, then
/// the line must start with `.target` followed by whitespace (so `.targets`
/// or `.targetfoo` are not mistaken for the directive). The architecture is
/// the first token after the directive, ending at a comma or any whitespace;
/// further comma-separated options on the line are ignored.
///
/// A `.target` line with no architecture token is skipped and the search
/// continues with later lines. Returns `None` when no line yields one.
fn ptx_target_arch(text: &str) -> Option<String> {
    text.lines().find_map(|line| {
        // Strip a trailing `//` comment so it cannot be glued onto the arch
        // token (e.g. `.target sm_80// note` or `.target sm_80<TAB>// note`).
        let code = line.split("//").next().unwrap_or("").trim();
        let rest = code.strip_prefix(".target")?;
        if !rest.is_empty() && !rest.starts_with(char::is_whitespace) {
            // Some longer directive that merely shares the `.target` prefix.
            return None;
        }
        let arch = rest
            .trim_start()
            .split(|c: char| c == ',' || c.is_whitespace())
            .next()
            .unwrap_or("");
        (!arch.is_empty()).then(|| arch.to_string())
    })
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use xpile_backend::{Profile, QuorumStatus};
    use xpile_meta_hir::SourceLang;

    // ---- fixtures -------------------------------------------------------

    // Empty Rust-sourced module carrying only a name.
    fn empty_module(name: &'static str) -> Module {
        Module {
            name: name.into(),
            source_lang: SourceLang::Rust,
            items: Vec::new(),
            ffi_boundaries: Vec::new(),
        }
    }

    // Module whose name does NOT match the matmul specialist.
    fn dummy_module() -> Module {
        empty_module("test_kernel")
    }

    // Module whose name matches the specialist's `matmul_` prefix.
    fn matmul_module() -> Module {
        empty_module("matmul_gemm_fp16")
    }

    // PTX config with the given optional hardware profile.
    fn config_with(hardware: Option<HwProfile>) -> BackendConfig {
        BackendConfig {
            target: Target::Ptx,
            profile: Profile::RustOut,
            hardware,
        }
    }

    // PTX config targeting compute capability `sm`.
    fn ptx_config(sm: &str) -> BackendConfig {
        config_with(Some(HwProfile::Ptx {
            compute_capability: String::from(sm),
        }))
    }

    // Text produced by the single-emitter backend for the dummy module.
    fn scaffold_text() -> String {
        PtxBackend::new()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap()
            .primary
    }

    // ---- single-emitter backend ----------------------------------------

    #[test]
    fn ptx_backend_emits_through_multi_emitter() {
        let artifact = PtxBackend::new()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();
        let expected_status = QuorumStatus::Single {
            emitter: "xpile-ptx-codegen-scaffold".to_string(),
        };
        assert_eq!(artifact.quorum_status, expected_status);
        assert!(artifact.primary.contains("sm_80"));
        let cites_mma_contract = artifact
            .citations
            .iter()
            .any(|c| c.as_str() == "C-COMPILE-RUST-TO-PTX-MMA");
        assert!(cites_mma_contract);
    }

    #[test]
    fn ptx_backend_rejects_missing_hardware() {
        let cfg = config_with(None);
        let err = PtxBackend::new()
            .lower(&dummy_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    #[test]
    fn ptx_backend_targets_only_ptx() {
        let backend = PtxBackend::new();
        assert_eq!(backend.targets(), &[Target::Ptx]);
        assert_eq!(backend.name(), "ptx");
    }

    // ---- specialist-enabled backend ------------------------------------

    #[test]
    fn matmul_module_routes_through_specialist_under_multi_emitter() {
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&matmul_module(), &ptx_config("sm_80"))
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "matmul-specialist-mock".to_string()
            },
            "PreferSpecialist with matching specialist should report Single {{ specialist }}"
        );
        assert!(
            artifact.primary.contains("matmul-specialist"),
            "primary should carry the specialist's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn non_matmul_module_falls_back_to_general_under_multi_emitter() {
        let artifact = PtxBackend::new_with_matmul_specialist()
            .lower(&dummy_module(), &ptx_config("sm_80"))
            .unwrap();
        assert_eq!(
            artifact.quorum_status,
            QuorumStatus::Single {
                emitter: "xpile-ptx-codegen-scaffold".to_string()
            },
            "non-matching specialist should let general emit; QuorumStatus should reflect general"
        );
        assert!(
            artifact.primary.contains("xpile-ptx-codegen scaffold"),
            "primary should carry the general scaffold's emission body, got:\n{}",
            artifact.primary,
        );
    }

    #[test]
    fn multi_emitter_constructor_targets_match_single_emitter() {
        let single = PtxBackend::new();
        let multi = PtxBackend::new_with_matmul_specialist();
        assert_eq!(multi.targets(), single.targets());
        assert_eq!(multi.name(), single.name());
    }

    #[test]
    fn multi_emitter_constructor_rejects_missing_hardware() {
        let cfg = config_with(None);
        let err = PtxBackend::new_with_matmul_specialist()
            .lower(&matmul_module(), &cfg)
            .unwrap_err();
        assert!(matches!(err, BackendError::MissingHardware(Target::Ptx)));
    }

    // ---- PTX validation -------------------------------------------------

    // Minimal well-formed PTX for sm_80. The body lines sit at column 0
    // because everything after the first line is literal string content.
    const GOLDEN_PTX_SM80: &str = "\
//
// Generated by LLVM NVPTX Back-End
//
.version 6.0
.target sm_80
.address_size 64

\t.visible .entry add_one(
\t\t.param .u64 add_one_param_0
\t)
\t{
\t\tret;
\t}
";

    #[test]
    fn validate_ptx_accepts_well_formed_kernel() {
        let verdict = validate_ptx(GOLDEN_PTX_SM80, "sm_80");
        assert_eq!(verdict, Ok(()));
    }

    #[test]
    fn ptx_looks_real_classifies_golden_vs_scaffold() {
        assert!(ptx_looks_real(GOLDEN_PTX_SM80));
        let scaffold = scaffold_text();
        assert!(!ptx_looks_real(&scaffold));
    }

    #[test]
    fn validate_ptx_rejects_scaffold_placeholder() {
        let scaffold = scaffold_text();
        assert_eq!(
            validate_ptx(&scaffold, "sm_80"),
            Err(PtxValidationError::MissingVersion)
        );
    }

    #[test]
    fn validate_ptx_detects_target_mismatch() {
        let expected_err = PtxValidationError::TargetMismatch {
            expected: "sm_90".into(),
            found: "sm_80".into(),
        };
        assert_eq!(validate_ptx(GOLDEN_PTX_SM80, "sm_90"), Err(expected_err));
    }

    #[test]
    fn validate_ptx_requires_address_size_and_entry() {
        let no_addr = ".version 6.0\n.target sm_80\n.visible .entry k() { ret; }\n";
        let no_entry = ".version 6.0\n.target sm_80\n.address_size 64\n";
        assert_eq!(
            validate_ptx(no_addr, "sm_80"),
            Err(PtxValidationError::MissingAddressSize)
        );
        assert_eq!(
            validate_ptx(no_entry, "sm_80"),
            Err(PtxValidationError::MissingEntry)
        );
    }

    #[test]
    fn ptxas_arch_derives_from_capability_not_hardcoded() {
        for (capability, flag) in [("sm_89", "-arch=sm_89"), ("sm_90", "-arch=sm_90")] {
            assert_eq!(ptxas_arch(capability), flag);
        }
    }
}