1use std::{
2 cell::UnsafeCell,
3 ffi::{CStr, CString},
4 fmt::{self, Display},
5 ptr, result,
6 sync::atomic::{AtomicBool, Ordering},
7};
8
9use singe_core::impl_enum_display;
10use singe_cuda_sys::nvrtc as sys;
11
12use crate::{
13 architecture::GpuArchitecture,
14 error::{Error, Result},
15 module::ModuleImage,
16 try_nvrtc,
17};
18
/// A `major.minor` version pair.
///
/// NOTE(review): nothing in this chunk constructs a `Version`; presumably it
/// carries the NVRTC library version — confirm against its producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Version {
    /// Major version component.
    pub major: i32,
    /// Minor version component.
    pub minor: i32,
}
24
/// Typed view of an `nvrtcResult` status code.
///
/// Every named variant corresponds to exactly one `NVRTC_*` constant of
/// `sys::nvrtcResult` (see `Status::raw`). A code this wrapper does not know
/// about is kept verbatim in [`Status::Unknown`] rather than being dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Status {
    /// `NVRTC_SUCCESS`.
    Success,
    /// `NVRTC_ERROR_OUT_OF_MEMORY`.
    OutOfMemory,
    /// `NVRTC_ERROR_PROGRAM_CREATION_FAILURE`.
    ProgramCreationFailure,
    /// `NVRTC_ERROR_INVALID_INPUT`.
    InvalidInput,
    /// `NVRTC_ERROR_INVALID_PROGRAM`.
    InvalidProgram,
    /// `NVRTC_ERROR_INVALID_OPTION`.
    InvalidOption,
    /// `NVRTC_ERROR_COMPILATION`.
    Compilation,
    /// `NVRTC_ERROR_BUILTIN_OPERATION_FAILURE`.
    BuiltinOperationFailure,
    /// `NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION`.
    NoNameExpressionsAfterCompilation,
    /// `NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION`.
    NoLoweredNamesBeforeCompilation,
    /// `NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID`.
    NameExpressionNotValid,
    /// `NVRTC_ERROR_INTERNAL_ERROR`.
    InternalError,
    /// `NVRTC_ERROR_TIME_FILE_WRITE_FAILED`.
    TimeFileWriteFailed,
    /// `NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED`.
    NoPchCreateAttempted,
    /// `NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED`.
    PchCreateHeapExhausted,
    /// `NVRTC_ERROR_PCH_CREATE`.
    PchCreate,
    /// `NVRTC_ERROR_CANCELLED`.
    Cancelled,
    /// `NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED`.
    TimeTraceFileWriteFailed,
    /// A raw status code with no named variant in this wrapper.
    Unknown(u32),
}
49
50impl Status {
51 pub fn description(self) -> String {
52 match sys::nvrtcResult::try_from(self.raw()) {
53 Ok(status) => unsafe {
54 let ptr = sys::nvrtcGetErrorString(status);
55 if ptr.is_null() {
56 String::from("unknown nvrtc error")
57 } else {
58 CStr::from_ptr(ptr).to_string_lossy().into_owned()
59 }
60 },
61 Err(_) => String::from("unknown nvrtc error"),
62 }
63 }
64
65 pub const fn raw(self) -> u32 {
66 match self {
67 Self::Success => sys::nvrtcResult::NVRTC_SUCCESS as _,
68 Self::OutOfMemory => sys::nvrtcResult::NVRTC_ERROR_OUT_OF_MEMORY as _,
69 Self::ProgramCreationFailure => {
70 sys::nvrtcResult::NVRTC_ERROR_PROGRAM_CREATION_FAILURE as _
71 }
72 Self::InvalidInput => sys::nvrtcResult::NVRTC_ERROR_INVALID_INPUT as _,
73 Self::InvalidProgram => sys::nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM as _,
74 Self::InvalidOption => sys::nvrtcResult::NVRTC_ERROR_INVALID_OPTION as _,
75 Self::Compilation => sys::nvrtcResult::NVRTC_ERROR_COMPILATION as _,
76 Self::BuiltinOperationFailure => {
77 sys::nvrtcResult::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE as _
78 }
79 Self::NoNameExpressionsAfterCompilation => {
80 sys::nvrtcResult::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION as _
81 }
82 Self::NoLoweredNamesBeforeCompilation => {
83 sys::nvrtcResult::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION as _
84 }
85 Self::NameExpressionNotValid => {
86 sys::nvrtcResult::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID as _
87 }
88 Self::InternalError => sys::nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR as _,
89 Self::TimeFileWriteFailed => sys::nvrtcResult::NVRTC_ERROR_TIME_FILE_WRITE_FAILED as _,
90 Self::NoPchCreateAttempted => {
91 sys::nvrtcResult::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED as _
92 }
93 Self::PchCreateHeapExhausted => {
94 sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED as _
95 }
96 Self::PchCreate => sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE as _,
97 Self::Cancelled => sys::nvrtcResult::NVRTC_ERROR_CANCELLED as _,
98 Self::TimeTraceFileWriteFailed => {
99 sys::nvrtcResult::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED as _
100 }
101 Self::Unknown(code) => code,
102 }
103 }
104}
105
106impl TryFrom<u32> for Status {
107 type Error = u32;
108
109 fn try_from(code: u32) -> result::Result<Self, u32> {
110 match code {
111 code if code == sys::nvrtcResult::NVRTC_SUCCESS as u32 => Ok(Self::Success),
112 code if code == sys::nvrtcResult::NVRTC_ERROR_OUT_OF_MEMORY as u32 => {
113 Ok(Self::OutOfMemory)
114 }
115 code if code == sys::nvrtcResult::NVRTC_ERROR_PROGRAM_CREATION_FAILURE as u32 => {
116 Ok(Self::ProgramCreationFailure)
117 }
118 code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_INPUT as u32 => {
119 Ok(Self::InvalidInput)
120 }
121 code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM as u32 => {
122 Ok(Self::InvalidProgram)
123 }
124 code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_OPTION as u32 => {
125 Ok(Self::InvalidOption)
126 }
127 code if code == sys::nvrtcResult::NVRTC_ERROR_COMPILATION as u32 => {
128 Ok(Self::Compilation)
129 }
130 code if code == sys::nvrtcResult::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE as u32 => {
131 Ok(Self::BuiltinOperationFailure)
132 }
133 code if code
134 == sys::nvrtcResult::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION as u32 =>
135 {
136 Ok(Self::NoNameExpressionsAfterCompilation)
137 }
138 code if code
139 == sys::nvrtcResult::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION as u32 =>
140 {
141 Ok(Self::NoLoweredNamesBeforeCompilation)
142 }
143 code if code == sys::nvrtcResult::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID as u32 => {
144 Ok(Self::NameExpressionNotValid)
145 }
146 code if code == sys::nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR as u32 => {
147 Ok(Self::InternalError)
148 }
149 code if code == sys::nvrtcResult::NVRTC_ERROR_TIME_FILE_WRITE_FAILED as u32 => {
150 Ok(Self::TimeFileWriteFailed)
151 }
152 code if code == sys::nvrtcResult::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED as u32 => {
153 Ok(Self::NoPchCreateAttempted)
154 }
155 code if code == sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED as u32 => {
156 Ok(Self::PchCreateHeapExhausted)
157 }
158 code if code == sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE as u32 => Ok(Self::PchCreate),
159 code if code == sys::nvrtcResult::NVRTC_ERROR_CANCELLED as u32 => Ok(Self::Cancelled),
160 code if code == sys::nvrtcResult::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED as u32 => {
161 Ok(Self::TimeTraceFileWriteFailed)
162 }
163 code => Err(code),
164 }
165 }
166}
167
168impl From<sys::nvrtcResult> for Status {
169 fn from(status: sys::nvrtcResult) -> Self {
170 Self::try_from(status as u32).unwrap_or_else(Self::Unknown)
171 }
172}
173
174impl TryFrom<Status> for sys::nvrtcResult {
175 type Error = Status;
176
177 fn try_from(status: Status) -> result::Result<Self, Status> {
178 match status {
179 Status::Success => Ok(Self::NVRTC_SUCCESS),
180 Status::OutOfMemory => Ok(Self::NVRTC_ERROR_OUT_OF_MEMORY),
181 Status::ProgramCreationFailure => Ok(Self::NVRTC_ERROR_PROGRAM_CREATION_FAILURE),
182 Status::InvalidInput => Ok(Self::NVRTC_ERROR_INVALID_INPUT),
183 Status::InvalidProgram => Ok(Self::NVRTC_ERROR_INVALID_PROGRAM),
184 Status::InvalidOption => Ok(Self::NVRTC_ERROR_INVALID_OPTION),
185 Status::Compilation => Ok(Self::NVRTC_ERROR_COMPILATION),
186 Status::BuiltinOperationFailure => Ok(Self::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE),
187 Status::NoNameExpressionsAfterCompilation => {
188 Ok(Self::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION)
189 }
190 Status::NoLoweredNamesBeforeCompilation => {
191 Ok(Self::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION)
192 }
193 Status::NameExpressionNotValid => Ok(Self::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID),
194 Status::InternalError => Ok(Self::NVRTC_ERROR_INTERNAL_ERROR),
195 Status::TimeFileWriteFailed => Ok(Self::NVRTC_ERROR_TIME_FILE_WRITE_FAILED),
196 Status::NoPchCreateAttempted => Ok(Self::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED),
197 Status::PchCreateHeapExhausted => Ok(Self::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED),
198 Status::PchCreate => Ok(Self::NVRTC_ERROR_PCH_CREATE),
199 Status::Cancelled => Ok(Self::NVRTC_ERROR_CANCELLED),
200 Status::TimeTraceFileWriteFailed => Ok(Self::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED),
201 Status::Unknown(_) => Err(status),
202 }
203 }
204}
205
206impl Display for Status {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Self::Success => f.write_str("NVRTC_SUCCESS"),
210 Self::OutOfMemory => f.write_str("NVRTC_ERROR_OUT_OF_MEMORY"),
211 Self::ProgramCreationFailure => f.write_str("NVRTC_ERROR_PROGRAM_CREATION_FAILURE"),
212 Self::InvalidInput => f.write_str("NVRTC_ERROR_INVALID_INPUT"),
213 Self::InvalidProgram => f.write_str("NVRTC_ERROR_INVALID_PROGRAM"),
214 Self::InvalidOption => f.write_str("NVRTC_ERROR_INVALID_OPTION"),
215 Self::Compilation => f.write_str("NVRTC_ERROR_COMPILATION"),
216 Self::BuiltinOperationFailure => f.write_str("NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"),
217 Self::NoNameExpressionsAfterCompilation => {
218 f.write_str("NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION")
219 }
220 Self::NoLoweredNamesBeforeCompilation => {
221 f.write_str("NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION")
222 }
223 Self::NameExpressionNotValid => f.write_str("NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID"),
224 Self::InternalError => f.write_str("NVRTC_ERROR_INTERNAL_ERROR"),
225 Self::TimeFileWriteFailed => f.write_str("NVRTC_ERROR_TIME_FILE_WRITE_FAILED"),
226 Self::NoPchCreateAttempted => f.write_str("NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED"),
227 Self::PchCreateHeapExhausted => f.write_str("NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED"),
228 Self::PchCreate => f.write_str("NVRTC_ERROR_PCH_CREATE"),
229 Self::Cancelled => f.write_str("NVRTC_ERROR_CANCELLED"),
230 Self::TimeTraceFileWriteFailed => {
231 f.write_str("NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED")
232 }
233 Self::Unknown(code) => write!(f, "UNKNOWN_NVRTC_STATUS({code})"),
234 }
235 }
236}
237
/// A borrowed in-memory header to register with a [`Program`].
///
/// Headers are copied into an owned form when attached with `with_header`
/// or `with_headers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Header<'a> {
    /// Full text of the header.
    pub source: &'a str,
    /// Name the header is registered under. Presumably this is the name used in
    /// `#include` directives — confirm in `create_handle`, which lies outside
    /// this view.
    pub include_name: &'a str,
}

/// Owned copy of a [`Header`], kept by [`Program`] so that the builder does
/// not borrow from the caller.
#[derive(Debug, Clone)]
struct OwnedHeader {
    source: String,
    include_name: String,
}
249
/// A preprocessor macro to define, emitted as `--define-macro=...`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum MacroDefinition<'a> {
    /// Defines `name` without an explicit value.
    Name(&'a str),
    /// Defines `name` with the given `value`.
    WithValue { name: &'a str, value: &'a str },
}

impl MacroDefinition<'_> {
    /// Renders the text that follows `--define-macro=`: either `name` or
    /// `name=value`.
    fn format(self) -> String {
        let (name, value) = match self {
            Self::Name(name) => (name, None),
            Self::WithValue { name, value } => (name, Some(value)),
        };
        match value {
            Some(value) => format!("{name}={value}"),
            None => name.to_owned(),
        }
    }
}
265
/// C++ language dialect, emitted as `--std=<dialect>` by
/// [`CompileOptions::as_arguments`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CppDialect {
    /// Rendered as `c++03`.
    Cpp03,
    /// Rendered as `c++11`.
    Cpp11,
    /// Rendered as `c++14`.
    Cpp14,
    /// Rendered as `c++17`.
    Cpp17,
    /// Rendered as `c++20`.
    Cpp20,
}

// The displayed text is substituted verbatim into `--std=` by `as_arguments`.
impl_enum_display!(CppDialect, {
    Self::Cpp03 => "c++03",
    Self::Cpp11 => "c++11",
    Self::Cpp14 => "c++14",
    Self::Cpp17 => "c++17",
    Self::Cpp20 => "c++20",
});
283
/// Level for `--Ofast-compile=<level>`, emitted by
/// [`CompileOptions::as_arguments`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FastCompileLevel {
    /// Rendered as `0`.
    Zero,
    /// Rendered as `min`.
    Min,
    /// Rendered as `mid`.
    Mid,
    /// Rendered as `max`.
    Max,
}

// The displayed text is substituted verbatim into `--Ofast-compile=`.
impl_enum_display!(FastCompileLevel, {
    Self::Zero => "0",
    Self::Min => "min",
    Self::Mid => "mid",
    Self::Max => "max",
});
299
/// Warning category to promote to an error. All requested kinds are joined
/// into a single `--warning-as-error=...` argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum WarningAsErrorKind {
    /// Rendered as `all-warnings`.
    AllWarnings,
    /// Rendered as `reorder`.
    Reorder,
    /// Rendered as `deprecated-declarations`.
    DeprecatedDeclarations,
}

// The displayed text is what `join_display` combines for `--warning-as-error=`.
impl_enum_display!(WarningAsErrorKind, {
    Self::AllWarnings => "all-warnings",
    Self::Reorder => "reorder",
    Self::DeprecatedDeclarations => "deprecated-declarations",
});
313
/// Kind of optimization report to request. Each entry is emitted as its own
/// `--optimization-info=<kind>` argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum OptimizationInfoKind {
    /// Rendered as `inline`.
    Inline,
}

// The displayed text is substituted verbatim into `--optimization-info=`.
impl_enum_display!(OptimizationInfoKind, {
    Self::Inline => "inline",
});
323
/// Typed NVRTC compile options.
///
/// [`CompileOptions::as_arguments`] turns these into command-line strings.
/// - Unset values (`None`, `false`, empty lists) emit nothing.
/// - `Option<bool>` values emit `--flag=<bool>`, rendered by `bool_flag`,
///   except where noted.
#[derive(Debug, Clone, Default)]
pub struct CompileOptions<'a> {
    /// `--gpu-architecture=<arch>`.
    pub gpu_architecture: Option<GpuArchitecture>,
    /// `--relocatable-device-code=<bool>`.
    pub relocatable_device_code: Option<bool>,
    /// `--extensible-whole-program` when `true`.
    pub extensible_whole_program: bool,
    /// `--device-debug` when `true`.
    pub device_debug: bool,
    /// `--generate-line-info` when `true`.
    pub generate_line_info: bool,
    /// `--dopt=on` for `Some(true)` only; `Some(false)` emits nothing.
    pub device_optimization: Option<bool>,
    /// `--Ofast-compile=<level>`.
    pub fast_compile: Option<FastCompileLevel>,
    /// One `--ptxas-options=<opt>` per entry.
    pub ptxas_options: Vec<&'a str>,
    /// `--maxrregcount=<n>`.
    pub max_register_count: Option<i32>,
    /// `--ftz=<bool>`.
    pub flush_to_zero: Option<bool>,
    /// `--prec-sqrt=<bool>`.
    pub precise_square_root: Option<bool>,
    /// `--prec-div=<bool>`.
    pub precise_division: Option<bool>,
    /// `--fmad=<bool>`.
    pub fmad: Option<bool>,
    /// `--use_fast_math` when `true`.
    pub use_fast_math: bool,
    /// `--extra-device-vectorization` when `true`.
    pub extra_device_vectorization: bool,
    /// `--modify-stack-limit=<bool>`.
    pub modify_stack_limit: Option<bool>,
    /// `--dlink-time-opt` when `true`.
    pub dlink_time_optimization: bool,
    /// `--gen-opt-lto` when `true`.
    pub generate_optimized_lto: bool,
    /// `--optix-ir` when `true`.
    pub optix_ir: bool,
    /// `--jump-table-density=<n>`. The builder setter clamps to 101; direct
    /// writes to this field are passed through unclamped.
    pub jump_table_density: Option<u8>,
    /// `--no-cache` when `true`.
    pub no_cache: bool,
    /// `--frandom-seed=<seed>`.
    pub random_seed: Option<&'a str>,
    /// One `--define-macro=<def>` per entry.
    pub define_macros: Vec<MacroDefinition<'a>>,
    /// One `--undefine-macro=<name>` per entry.
    pub undefine_macros: Vec<&'a str>,
    /// One `--include-path=<dir>` per entry.
    pub include_paths: Vec<&'a str>,
    /// One `--pre-include=<header>` per entry.
    pub pre_include_headers: Vec<&'a str>,
    /// `--no-source-include` when `true`.
    pub no_source_include: bool,
    /// `--std=<dialect>`.
    pub cpp_dialect: Option<CppDialect>,
    /// `--builtin-move-forward=<bool>`.
    pub builtin_move_forward: Option<bool>,
    /// `--builtin-initializer-list=<bool>`.
    pub builtin_initializer_list: Option<bool>,
    /// `--pch` when `true`.
    pub pch: bool,
    /// `--create-pch=<file>`.
    pub create_pch: Option<&'a str>,
    /// `--use-pch=<file>`.
    pub use_pch: Option<&'a str>,
    /// `--pch-dir=<dir>`.
    pub pch_dir: Option<&'a str>,
    /// `--pch-verbose=<bool>`.
    pub pch_verbose: Option<bool>,
    /// `--pch-messages=<bool>`.
    pub pch_messages: Option<bool>,
    /// `--instantiate-templates-in-pch=<bool>`.
    pub instantiate_templates_in_pch: Option<bool>,
    /// `--disable-warnings` when `true`.
    pub disable_warnings: bool,
    /// All entries joined into one `--warning-as-error=...` argument.
    pub warning_as_error: Vec<WarningAsErrorKind>,
    /// `--restrict` when `true`.
    pub restrict_pointers: bool,
    /// `--device-as-default-execution-space` when `true`.
    pub device_as_default_execution_space: bool,
    /// `--device-int128` when `true`.
    pub device_int128: bool,
    /// `--device-float128` when `true`.
    pub device_float128: bool,
    /// One `--optimization-info=<kind>` per entry.
    pub optimization_info: Vec<OptimizationInfoKind>,
    /// Emits one of two forms:
    /// - `Some(true)` → `--display-error-number`;
    /// - `Some(false)` → `--no-display-error-number`.
    pub display_error_number: Option<bool>,
    /// All entries joined into one `--diag-error=...` argument.
    pub diag_error: Vec<i32>,
    /// All entries joined into one `--diag-suppress=...` argument.
    pub diag_suppress: Vec<i32>,
    /// All entries joined into one `--diag-warn=...` argument.
    pub diag_warn: Vec<i32>,
    /// `--brief-diagnostics=<bool>`.
    pub brief_diagnostics: Option<bool>,
    /// `--time=<file>`.
    pub time: Option<&'a str>,
    /// `--split-compile=<n>`.
    pub split_compile: Option<i32>,
    /// `--fdevice-syntax-only` when `true`.
    pub device_syntax_only: bool,
    /// `--minimal` when `true`.
    pub minimal: bool,
    /// `--device-stack-protector=<bool>`.
    pub device_stack_protector: Option<bool>,
    /// `--fdevice-time-trace=<file>`.
    pub device_time_trace: Option<&'a str>,
    /// Extra arguments appended verbatim, after all typed options.
    pub raw_options: Vec<&'a str>,
}
383
384impl<'a> CompileOptions<'a> {
385 pub const fn new() -> Self {
386 Self {
387 gpu_architecture: None,
388 relocatable_device_code: None,
389 extensible_whole_program: false,
390 device_debug: false,
391 generate_line_info: false,
392 device_optimization: None,
393 fast_compile: None,
394 ptxas_options: Vec::new(),
395 max_register_count: None,
396 flush_to_zero: None,
397 precise_square_root: None,
398 precise_division: None,
399 fmad: None,
400 use_fast_math: false,
401 extra_device_vectorization: false,
402 modify_stack_limit: None,
403 dlink_time_optimization: false,
404 generate_optimized_lto: false,
405 optix_ir: false,
406 jump_table_density: None,
407 no_cache: false,
408 random_seed: None,
409 define_macros: Vec::new(),
410 undefine_macros: Vec::new(),
411 include_paths: Vec::new(),
412 pre_include_headers: Vec::new(),
413 no_source_include: false,
414 cpp_dialect: None,
415 builtin_move_forward: None,
416 builtin_initializer_list: None,
417 pch: false,
418 create_pch: None,
419 use_pch: None,
420 pch_dir: None,
421 pch_verbose: None,
422 pch_messages: None,
423 instantiate_templates_in_pch: None,
424 disable_warnings: false,
425 warning_as_error: Vec::new(),
426 restrict_pointers: false,
427 device_as_default_execution_space: false,
428 device_int128: false,
429 device_float128: false,
430 optimization_info: Vec::new(),
431 display_error_number: None,
432 diag_error: Vec::new(),
433 diag_suppress: Vec::new(),
434 diag_warn: Vec::new(),
435 brief_diagnostics: None,
436 time: None,
437 split_compile: None,
438 device_syntax_only: false,
439 minimal: false,
440 device_stack_protector: None,
441 device_time_trace: None,
442 raw_options: Vec::new(),
443 }
444 }
445
446 pub fn gpu_architecture(mut self, value: GpuArchitecture) -> Self {
447 self.gpu_architecture = Some(value);
448 self
449 }
450
451 pub fn relocatable_device_code(mut self, value: bool) -> Self {
452 self.relocatable_device_code = Some(value);
453 self
454 }
455
456 pub fn extensible_whole_program(mut self, value: bool) -> Self {
457 self.extensible_whole_program = value;
458 self
459 }
460
461 pub fn device_debug(mut self, value: bool) -> Self {
462 self.device_debug = value;
463 self
464 }
465
466 pub fn generate_line_info(mut self, value: bool) -> Self {
467 self.generate_line_info = value;
468 self
469 }
470
471 pub fn device_optimization(mut self, value: bool) -> Self {
472 self.device_optimization = Some(value);
473 self
474 }
475
476 pub fn fast_compile(mut self, value: FastCompileLevel) -> Self {
477 self.fast_compile = Some(value);
478 self
479 }
480
481 pub fn ptxas_option(mut self, value: &'a str) -> Self {
482 self.ptxas_options.push(value);
483 self
484 }
485
486 pub fn max_register_count(mut self, value: i32) -> Self {
487 self.max_register_count = Some(value);
488 self
489 }
490
491 pub fn flush_to_zero(mut self, value: bool) -> Self {
492 self.flush_to_zero = Some(value);
493 self
494 }
495
496 pub fn precise_square_root(mut self, value: bool) -> Self {
497 self.precise_square_root = Some(value);
498 self
499 }
500
501 pub fn precise_division(mut self, value: bool) -> Self {
502 self.precise_division = Some(value);
503 self
504 }
505
506 pub fn fmad(mut self, value: bool) -> Self {
507 self.fmad = Some(value);
508 self
509 }
510
511 pub fn use_fast_math(mut self, value: bool) -> Self {
512 self.use_fast_math = value;
513 self
514 }
515
516 pub fn extra_device_vectorization(mut self, value: bool) -> Self {
517 self.extra_device_vectorization = value;
518 self
519 }
520
521 pub fn modify_stack_limit(mut self, value: bool) -> Self {
522 self.modify_stack_limit = Some(value);
523 self
524 }
525
526 pub fn dlink_time_optimization(mut self, value: bool) -> Self {
527 self.dlink_time_optimization = value;
528 self
529 }
530
531 pub fn generate_optimized_lto(mut self, value: bool) -> Self {
532 self.generate_optimized_lto = value;
533 self
534 }
535
536 pub fn optix_ir(mut self, value: bool) -> Self {
537 self.optix_ir = value;
538 self
539 }
540
541 pub fn jump_table_density(mut self, value: u8) -> Self {
542 self.jump_table_density = Some(value.min(101));
543 self
544 }
545
546 pub fn no_cache(mut self, value: bool) -> Self {
547 self.no_cache = value;
548 self
549 }
550
551 pub fn random_seed(mut self, value: &'a str) -> Self {
552 self.random_seed = Some(value);
553 self
554 }
555
556 pub fn define_macro(mut self, value: MacroDefinition<'a>) -> Self {
557 self.define_macros.push(value);
558 self
559 }
560
561 pub fn undefine_macro(mut self, value: &'a str) -> Self {
562 self.undefine_macros.push(value);
563 self
564 }
565
566 pub fn include_path(mut self, value: &'a str) -> Self {
567 self.include_paths.push(value);
568 self
569 }
570
571 pub fn pre_include_header(mut self, value: &'a str) -> Self {
572 self.pre_include_headers.push(value);
573 self
574 }
575
576 pub fn no_source_include(mut self, value: bool) -> Self {
577 self.no_source_include = value;
578 self
579 }
580
581 pub fn cpp_dialect(mut self, value: CppDialect) -> Self {
582 self.cpp_dialect = Some(value);
583 self
584 }
585
586 pub fn builtin_move_forward(mut self, value: bool) -> Self {
587 self.builtin_move_forward = Some(value);
588 self
589 }
590
591 pub fn builtin_initializer_list(mut self, value: bool) -> Self {
592 self.builtin_initializer_list = Some(value);
593 self
594 }
595
596 pub fn pch(mut self, value: bool) -> Self {
597 self.pch = value;
598 self
599 }
600
601 pub fn create_pch(mut self, value: &'a str) -> Self {
602 self.create_pch = Some(value);
603 self
604 }
605
606 pub fn use_pch(mut self, value: &'a str) -> Self {
607 self.use_pch = Some(value);
608 self
609 }
610
611 pub fn pch_dir(mut self, value: &'a str) -> Self {
612 self.pch_dir = Some(value);
613 self
614 }
615
616 pub fn pch_verbose(mut self, value: bool) -> Self {
617 self.pch_verbose = Some(value);
618 self
619 }
620
621 pub fn pch_messages(mut self, value: bool) -> Self {
622 self.pch_messages = Some(value);
623 self
624 }
625
626 pub fn instantiate_templates_in_pch(mut self, value: bool) -> Self {
627 self.instantiate_templates_in_pch = Some(value);
628 self
629 }
630
631 pub fn disable_warnings(mut self, value: bool) -> Self {
632 self.disable_warnings = value;
633 self
634 }
635
636 pub fn warning_as_error(mut self, value: WarningAsErrorKind) -> Self {
637 self.warning_as_error.push(value);
638 self
639 }
640
641 pub fn restrict_pointers(mut self, value: bool) -> Self {
642 self.restrict_pointers = value;
643 self
644 }
645
646 pub fn device_as_default_execution_space(mut self, value: bool) -> Self {
647 self.device_as_default_execution_space = value;
648 self
649 }
650
651 pub fn device_int128(mut self, value: bool) -> Self {
652 self.device_int128 = value;
653 self
654 }
655
656 pub fn device_float128(mut self, value: bool) -> Self {
657 self.device_float128 = value;
658 self
659 }
660
661 pub fn optimization_info(mut self, value: OptimizationInfoKind) -> Self {
662 self.optimization_info.push(value);
663 self
664 }
665
666 pub fn display_error_number(mut self, value: bool) -> Self {
667 self.display_error_number = Some(value);
668 self
669 }
670
671 pub fn diag_error(mut self, value: i32) -> Self {
672 self.diag_error.push(value);
673 self
674 }
675
676 pub fn diag_suppress(mut self, value: i32) -> Self {
677 self.diag_suppress.push(value);
678 self
679 }
680
681 pub fn diag_warn(mut self, value: i32) -> Self {
682 self.diag_warn.push(value);
683 self
684 }
685
686 pub fn brief_diagnostics(mut self, value: bool) -> Self {
687 self.brief_diagnostics = Some(value);
688 self
689 }
690
691 pub fn time(mut self, value: &'a str) -> Self {
692 self.time = Some(value);
693 self
694 }
695
696 pub fn split_compile(mut self, value: i32) -> Self {
697 self.split_compile = Some(value);
698 self
699 }
700
701 pub fn device_syntax_only(mut self, value: bool) -> Self {
702 self.device_syntax_only = value;
703 self
704 }
705
706 pub fn minimal(mut self, value: bool) -> Self {
707 self.minimal = value;
708 self
709 }
710
711 pub fn device_stack_protector(mut self, value: bool) -> Self {
712 self.device_stack_protector = Some(value);
713 self
714 }
715
716 pub fn device_time_trace(mut self, value: &'a str) -> Self {
717 self.device_time_trace = Some(value);
718 self
719 }
720
721 pub fn raw_option(mut self, value: &'a str) -> Self {
722 self.raw_options.push(value);
723 self
724 }
725
    /// Renders the options as NVRTC command-line arguments, one `String` per
    /// argument.
    ///
    /// Ordering: arguments come out in a fixed order, with `raw_options`
    /// appended last.
    ///
    /// Unset options (`None`, `false`, or an empty list) produce nothing.
    /// Otherwise each kind of option is rendered as follows:
    /// - `Option<bool>` values become `--flag=<bool_flag(value)>`.
    /// - Some list options repeat once per entry (`--ptxas-options`,
    ///   `--define-macro`, `--include-path`, ...).
    /// - Others are combined into a single argument by `join_display` /
    ///   `join_numbers` (`--warning-as-error`, `--diag-*`).
    pub fn as_arguments(&self) -> Vec<String> {
        let mut arguments = Vec::new();

        if let Some(value) = self.gpu_architecture {
            arguments.push(format!("--gpu-architecture={value}"));
        }
        if let Some(value) = self.relocatable_device_code {
            arguments.push(format!("--relocatable-device-code={}", bool_flag(value)));
        }
        if self.extensible_whole_program {
            arguments.push(String::from("--extensible-whole-program"));
        }
        if self.device_debug {
            arguments.push(String::from("--device-debug"));
        }
        if self.generate_line_info {
            arguments.push(String::from("--generate-line-info"));
        }
        // Only `Some(true)` emits an argument. `--dopt` is only ever written
        // as `on` here, so `Some(false)` behaves exactly like `None`.
        if let Some(value) = self.device_optimization
            && value
        {
            arguments.push(String::from("--dopt=on"));
        }
        if let Some(value) = self.fast_compile {
            arguments.push(format!("--Ofast-compile={value}"));
        }
        // One argument per ptxas option, in insertion order.
        arguments.extend(
            self.ptxas_options
                .iter()
                .map(|value| format!("--ptxas-options={value}")),
        );
        if let Some(value) = self.max_register_count {
            arguments.push(format!("--maxrregcount={value}"));
        }
        if let Some(value) = self.flush_to_zero {
            arguments.push(format!("--ftz={}", bool_flag(value)));
        }
        if let Some(value) = self.precise_square_root {
            arguments.push(format!("--prec-sqrt={}", bool_flag(value)));
        }
        if let Some(value) = self.precise_division {
            arguments.push(format!("--prec-div={}", bool_flag(value)));
        }
        if let Some(value) = self.fmad {
            arguments.push(format!("--fmad={}", bool_flag(value)));
        }
        if self.use_fast_math {
            arguments.push(String::from("--use_fast_math"));
        }
        if self.extra_device_vectorization {
            arguments.push(String::from("--extra-device-vectorization"));
        }
        if let Some(value) = self.modify_stack_limit {
            arguments.push(format!("--modify-stack-limit={}", bool_flag(value)));
        }
        if self.dlink_time_optimization {
            arguments.push(String::from("--dlink-time-opt"));
        }
        if self.generate_optimized_lto {
            arguments.push(String::from("--gen-opt-lto"));
        }
        if self.optix_ir {
            arguments.push(String::from("--optix-ir"));
        }
        // NOTE(review): the builder setter clamps this to 101, but a value
        // written directly to the public field is emitted unclamped.
        if let Some(value) = self.jump_table_density {
            arguments.push(format!("--jump-table-density={value}"));
        }
        if self.no_cache {
            arguments.push(String::from("--no-cache"));
        }
        if let Some(value) = self.random_seed {
            arguments.push(format!("--frandom-seed={value}"));
        }
        // Preprocessor configuration: macros, include paths, forced includes.
        arguments.extend(
            self.define_macros
                .iter()
                .copied()
                .map(|value| format!("--define-macro={}", value.format())),
        );
        arguments.extend(
            self.undefine_macros
                .iter()
                .map(|value| format!("--undefine-macro={value}")),
        );
        arguments.extend(
            self.include_paths
                .iter()
                .map(|value| format!("--include-path={value}")),
        );
        arguments.extend(
            self.pre_include_headers
                .iter()
                .map(|value| format!("--pre-include={value}")),
        );
        if self.no_source_include {
            arguments.push(String::from("--no-source-include"));
        }
        if let Some(value) = self.cpp_dialect {
            arguments.push(format!("--std={value}"));
        }
        if let Some(value) = self.builtin_move_forward {
            arguments.push(format!("--builtin-move-forward={}", bool_flag(value)));
        }
        if let Some(value) = self.builtin_initializer_list {
            arguments.push(format!("--builtin-initializer-list={}", bool_flag(value)));
        }
        // Precompiled-header options.
        if self.pch {
            arguments.push(String::from("--pch"));
        }
        if let Some(value) = self.create_pch {
            arguments.push(format!("--create-pch={value}"));
        }
        if let Some(value) = self.use_pch {
            arguments.push(format!("--use-pch={value}"));
        }
        if let Some(value) = self.pch_dir {
            arguments.push(format!("--pch-dir={value}"));
        }
        if let Some(value) = self.pch_verbose {
            arguments.push(format!("--pch-verbose={}", bool_flag(value)));
        }
        if let Some(value) = self.pch_messages {
            arguments.push(format!("--pch-messages={}", bool_flag(value)));
        }
        if let Some(value) = self.instantiate_templates_in_pch {
            arguments.push(format!(
                "--instantiate-templates-in-pch={}",
                bool_flag(value)
            ));
        }
        // Warning and language-extension switches.
        if self.disable_warnings {
            arguments.push(String::from("--disable-warnings"));
        }
        // All requested kinds go into a single argument.
        if !self.warning_as_error.is_empty() {
            arguments.push(format!(
                "--warning-as-error={}",
                join_display(&self.warning_as_error)
            ));
        }
        if self.restrict_pointers {
            arguments.push(String::from("--restrict"));
        }
        if self.device_as_default_execution_space {
            arguments.push(String::from("--device-as-default-execution-space"));
        }
        if self.device_int128 {
            arguments.push(String::from("--device-int128"));
        }
        if self.device_float128 {
            arguments.push(String::from("--device-float128"));
        }
        arguments.extend(
            self.optimization_info
                .iter()
                .map(|value| format!("--optimization-info={value}")),
        );
        // This option is a pair of separate flags, not a `=<bool>` option.
        if let Some(value) = self.display_error_number {
            arguments.push(if value {
                String::from("--display-error-number")
            } else {
                String::from("--no-display-error-number")
            });
        }
        // Diagnostic-number lists: each list is combined into one argument.
        if !self.diag_error.is_empty() {
            arguments.push(format!("--diag-error={}", join_numbers(&self.diag_error)));
        }
        if !self.diag_suppress.is_empty() {
            arguments.push(format!(
                "--diag-suppress={}",
                join_numbers(&self.diag_suppress)
            ));
        }
        if !self.diag_warn.is_empty() {
            arguments.push(format!("--diag-warn={}", join_numbers(&self.diag_warn)));
        }
        if let Some(value) = self.brief_diagnostics {
            arguments.push(format!("--brief-diagnostics={}", bool_flag(value)));
        }
        if let Some(value) = self.time {
            arguments.push(format!("--time={value}"));
        }
        if let Some(value) = self.split_compile {
            arguments.push(format!("--split-compile={value}"));
        }
        if self.device_syntax_only {
            arguments.push(String::from("--fdevice-syntax-only"));
        }
        if self.minimal {
            arguments.push(String::from("--minimal"));
        }
        if let Some(value) = self.device_stack_protector {
            arguments.push(format!("--device-stack-protector={}", bool_flag(value)));
        }
        if let Some(value) = self.device_time_trace {
            arguments.push(format!("--fdevice-time-trace={value}"));
        }

        // Raw options always come last and are passed through untouched.
        arguments.extend(self.raw_options.iter().map(|value| (*value).to_string()));
        arguments
    }
926}
927
/// An NVRTC program built from in-memory CUDA source.
///
/// The underlying `nvrtcProgram` is created lazily. `handle` stays null until
/// the first operation that needs it. The handle is then created by
/// `create_handle`, which lies outside this view; presumably it uses the stored
/// source, name and headers.
#[derive(Debug)]
pub struct Program {
    /// CUDA source text (empty for programs adopted via `from_raw`).
    source: String,
    /// Optional program name.
    name: Option<String>,
    /// Headers attached through `with_header` / `with_headers`.
    headers: Vec<OwnedHeader>,
    /// Lazily-created NVRTC handle; null until materialized.
    ///
    /// `UnsafeCell` lets `&self` methods fill it in on first use. It also makes
    /// `Program` `!Sync`, unless a manual impl outside this view says otherwise.
    handle: UnsafeCell<sys::nvrtcProgram>,
}
935
936impl Program {
937 pub fn from_source(source: &str) -> Self {
942 Self {
943 source: source.to_string(),
944 name: None,
945 headers: Vec::new(),
946 handle: UnsafeCell::new(ptr::null_mut()),
947 }
948 }
949
950 pub fn with_name(mut self, name: &str) -> Self {
951 self.name = Some(name.to_string());
952 self
953 }
954
955 pub fn with_header(mut self, header: Header<'_>) -> Self {
956 self.headers.push(OwnedHeader {
957 source: header.source.to_string(),
958 include_name: header.include_name.to_string(),
959 });
960 self
961 }
962
963 pub fn with_headers(mut self, headers: &[Header<'_>]) -> Self {
964 self.headers
965 .extend(headers.iter().map(|header| OwnedHeader {
966 source: header.source.to_string(),
967 include_name: header.include_name.to_string(),
968 }));
969 self
970 }
971
972 pub unsafe fn from_raw(handle: sys::nvrtcProgram) -> Result<Self> {
979 if handle.is_null() {
980 return Err(Error::NullHandle);
981 }
982
983 Ok(Self {
984 source: String::new(),
985 name: None,
986 headers: Vec::new(),
987 handle: UnsafeCell::new(handle),
988 })
989 }
990
991 pub fn compile(&self, options: &[&str]) -> Result<()> {
992 self.compile_raw(options)
993 }
994
995 pub fn compile_with_options(&self, options: &CompileOptions<'_>) -> Result<()> {
996 let arguments = options.as_arguments();
997 let argument_refs = arguments.iter().map(String::as_str).collect::<Vec<_>>();
998 self.compile_raw(&argument_refs)
999 }
1000
1001 pub fn compile_with_options_and_cancel_flag(
1015 &self,
1016 options: &CompileOptions<'_>,
1017 cancel: &AtomicBool,
1018 ) -> Result<()> {
1019 unsafe {
1020 try_nvrtc!(sys::nvrtcSetFlowCallback(
1021 self.handle()?,
1022 Some(cancel_if_requested_callback),
1023 ptr::from_ref(cancel).cast_mut().cast(),
1024 ))?;
1025 }
1026
1027 let compile_result = self.compile_with_options(options);
1028 let clear_result = clear_flow_callback(self);
1029
1030 match (compile_result, clear_result) {
1031 (Err(error), _) => Err(error),
1032 (Ok(()), Err(error)) => Err(error),
1033 (Ok(()), Ok(())) => Ok(()),
1034 }
1035 }
1036
1037 pub fn add_name_expression(&self, name_expression: &str) -> Result<()> {
1046 let name_expression = CString::new(name_expression)?;
1047 unsafe {
1048 try_nvrtc!(sys::nvrtcAddNameExpression(
1049 self.handle()?,
1050 name_expression.as_ptr(),
1051 ))
1052 }
1053 }
1054
1055 pub fn lowered_name(&self, name_expression: &str) -> Result<String> {
1064 let name_expression = CString::new(name_expression)?;
1065 let mut lowered_name = ptr::null();
1066 unsafe {
1067 try_nvrtc!(sys::nvrtcGetLoweredName(
1068 self.handle()?,
1069 name_expression.as_ptr(),
1070 &raw mut lowered_name,
1071 ))?;
1072 Ok(CStr::from_ptr(lowered_name).to_string_lossy().into_owned())
1073 }
1074 }
1075
1076 pub fn ptx(&self) -> Result<Vec<u8>> {
1077 self.bytes(sys::nvrtcGetPTXSize, sys::nvrtcGetPTX)
1078 }
1079
1080 pub fn ptx_image(&self) -> Result<ModuleImage<'static>> {
1081 Ok(ModuleImage::from_vec(self.ptx()?))
1082 }
1083
1084 pub fn ptx_string(&self) -> Result<String> {
1085 Ok(bytes_to_string(self.ptx()?))
1086 }
1087
1088 pub fn cubin(&self) -> Result<Vec<u8>> {
1089 self.bytes(sys::nvrtcGetCUBINSize, sys::nvrtcGetCUBIN)
1090 }
1091
1092 pub fn cubin_image(&self) -> Result<ModuleImage<'static>> {
1093 Ok(ModuleImage::from_vec(self.cubin()?))
1094 }
1095
1096 pub fn lto_ir(&self) -> Result<Vec<u8>> {
1097 self.bytes(sys::nvrtcGetLTOIRSize, sys::nvrtcGetLTOIR)
1098 }
1099
1100 pub fn lto_ir_image(&self) -> Result<ModuleImage<'static>> {
1101 Ok(ModuleImage::from_vec(self.lto_ir()?))
1102 }
1103
1104 pub fn optix_ir(&self) -> Result<Vec<u8>> {
1105 self.bytes(sys::nvrtcGetOptiXIRSize, sys::nvrtcGetOptiXIR)
1106 }
1107
1108 pub fn optix_ir_image(&self) -> Result<ModuleImage<'static>> {
1109 Ok(ModuleImage::from_vec(self.optix_ir()?))
1110 }
1111
1112 pub fn log(&self) -> Result<String> {
1113 Ok(bytes_to_string(self.bytes(
1114 sys::nvrtcGetProgramLogSize,
1115 sys::nvrtcGetProgramLog,
1116 )?))
1117 }
1118
1119 pub fn pch_create_status(&self) -> Result<Status> {
1131 unsafe { Ok(sys::nvrtcGetPCHCreateStatus(self.handle()?).into()) }
1132 }
1133
1134 pub fn pch_heap_size_required(&self) -> Result<usize> {
1142 let mut size = 0;
1143 unsafe {
1144 try_nvrtc!(sys::nvrtcGetPCHHeapSizeRequired(
1145 self.handle()?,
1146 &raw mut size
1147 ))?;
1148 }
1149 Ok(size as usize)
1150 }
1151
1152 pub fn materialize_raw(&self) -> Result<sys::nvrtcProgram> {
1161 self.handle()
1162 }
1163
1164 pub const fn as_raw_if_materialized(&self) -> Option<sys::nvrtcProgram> {
1167 let handle = unsafe { *self.handle.get() };
1168 if handle.is_null() { None } else { Some(handle) }
1169 }
1170
1171 pub fn into_raw(self) -> Result<sys::nvrtcProgram> {
1177 let handle = self.handle()?;
1178 std::mem::forget(self);
1179 Ok(handle)
1180 }
1181
1182 fn compile_raw(&self, options: &[&str]) -> Result<()> {
1183 let options = options
1184 .iter()
1185 .map(|option| CString::new(*option))
1186 .collect::<result::Result<Vec<_>, _>>()?;
1187 let option_ptrs = options
1188 .iter()
1189 .map(|value| value.as_ptr())
1190 .collect::<Vec<_>>();
1191
1192 unsafe {
1193 try_nvrtc!(sys::nvrtcCompileProgram(
1194 self.handle()?,
1195 option_ptrs.len() as _,
1196 if option_ptrs.is_empty() {
1197 ptr::null()
1198 } else {
1199 option_ptrs.as_ptr()
1200 },
1201 ))
1202 }
1203 }
1204
1205 fn bytes(
1206 &self,
1207 get_size: unsafe extern "C" fn(sys::nvrtcProgram, *mut sys::size_t) -> sys::nvrtcResult,
1208 get_data: unsafe extern "C" fn(sys::nvrtcProgram, *mut i8) -> sys::nvrtcResult,
1209 ) -> Result<Vec<u8>> {
1210 let mut size = 0;
1211 unsafe {
1212 try_nvrtc!(get_size(self.handle()?, &raw mut size))?;
1213 }
1214
1215 let mut bytes = vec![0u8; size as usize];
1216 if bytes.is_empty() {
1217 return Ok(bytes);
1218 }
1219
1220 unsafe {
1221 try_nvrtc!(get_data(self.handle()?, bytes.as_mut_ptr().cast()))?;
1222 }
1223 Ok(bytes)
1224 }
1225
1226 fn handle(&self) -> Result<sys::nvrtcProgram> {
1227 unsafe {
1228 let handle = self.handle.get();
1229 if (*handle).is_null() {
1230 *handle = self.create_handle()?;
1231 }
1232 Ok(*handle)
1233 }
1234 }
1235
1236 fn create_handle(&self) -> Result<sys::nvrtcProgram> {
1237 let source = CString::new(self.source.as_str())?;
1238 let name = self.name.as_deref().map(CString::new).transpose()?;
1239 let header_sources = self
1240 .headers
1241 .iter()
1242 .map(|header| CString::new(header.source.as_str()))
1243 .collect::<result::Result<Vec<_>, _>>()?;
1244 let include_names = self
1245 .headers
1246 .iter()
1247 .map(|header| CString::new(header.include_name.as_str()))
1248 .collect::<result::Result<Vec<_>, _>>()?;
1249 let header_ptrs = header_sources
1250 .iter()
1251 .map(|value| value.as_ptr())
1252 .collect::<Vec<_>>();
1253 let include_name_ptrs = include_names
1254 .iter()
1255 .map(|value| value.as_ptr())
1256 .collect::<Vec<_>>();
1257
1258 let mut handle = ptr::null_mut();
1259 unsafe {
1260 try_nvrtc!(sys::nvrtcCreateProgram(
1261 &raw mut handle,
1262 source.as_ptr(),
1263 name.as_ref().map_or(ptr::null(), |value| value.as_ptr()),
1264 self.headers.len() as _,
1265 if header_ptrs.is_empty() {
1266 ptr::null()
1267 } else {
1268 header_ptrs.as_ptr()
1269 },
1270 if include_name_ptrs.is_empty() {
1271 ptr::null()
1272 } else {
1273 include_name_ptrs.as_ptr()
1274 },
1275 ))?;
1276 }
1277
1278 Ok(handle)
1279 }
1280}
1281
1282impl Drop for Program {
1283 fn drop(&mut self) {
1284 unsafe {
1285 let handle = self.handle.get();
1286 if !(*handle).is_null() {
1287 let _ = sys::nvrtcDestroyProgram(handle);
1288 }
1289 }
1290 }
1291}
1292
/// Compiled output of a [`Program`], tagged with the kind of image it contains.
///
/// Obtained from [`Program::artifact`]. Every variant wraps an owned [`ModuleImage`].
#[non_exhaustive]
pub enum CompilationArtifact {
    /// PTX assembly text.
    Ptx(ModuleImage<'static>),
    /// CUBIN device binary for a real `sm_XX` architecture.
    Cubin(ModuleImage<'static>),
    /// LTO intermediate representation (produced when link-time optimization is enabled).
    LtoIr(ModuleImage<'static>),
    /// OptiX IR (produced when OptiX IR output is enabled).
    OptixIr(ModuleImage<'static>),
}
1300
1301impl CompilationArtifact {
1302 pub fn image(&self) -> &ModuleImage<'static> {
1303 match self {
1304 Self::Ptx(image) | Self::Cubin(image) | Self::LtoIr(image) | Self::OptixIr(image) => {
1305 image
1306 }
1307 }
1308 }
1309
1310 pub fn into_image(self) -> ModuleImage<'static> {
1311 match self {
1312 Self::Ptx(image) | Self::Cubin(image) | Self::LtoIr(image) | Self::OptixIr(image) => {
1313 image
1314 }
1315 }
1316 }
1317}
1318
/// Selects which compiled output [`Program::artifact`] should return.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum OutputKind {
    /// PTX assembly (see `Program::ptx`).
    Ptx,
    /// CUBIN device binary (see `Program::cubin`).
    Cubin,
    /// LTO intermediate representation (see `Program::lto_ir`).
    LtoIr,
    /// OptiX IR (see `Program::optix_ir`).
    OptixIr,
}
1327
1328impl Program {
1329 pub fn artifact(&self, kind: OutputKind) -> Result<CompilationArtifact> {
1330 match kind {
1331 OutputKind::Ptx => {
1332 let image = self.ptx_image()?;
1333 if image.as_bytes().is_empty() {
1334 return Err(Error::InvalidValue);
1335 }
1336 Ok(CompilationArtifact::Ptx(image))
1337 }
1338 OutputKind::Cubin => {
1339 let image = self.cubin_image()?;
1340 if image.as_bytes().is_empty() {
1341 return Err(Error::InvalidValue);
1342 }
1343 Ok(CompilationArtifact::Cubin(image))
1344 }
1345 OutputKind::LtoIr => {
1346 let image = self.lto_ir_image()?;
1347 if image.as_bytes().is_empty() {
1348 return Err(Error::InvalidValue);
1349 }
1350 Ok(CompilationArtifact::LtoIr(image))
1351 }
1352 OutputKind::OptixIr => {
1353 let image = self.optix_ir_image()?;
1354 if image.as_bytes().is_empty() {
1355 return Err(Error::InvalidValue);
1356 }
1357 Ok(CompilationArtifact::OptixIr(image))
1358 }
1359 }
1360 }
1361}
1362
1363pub fn version() -> Result<Version> {
1369 let mut major = 0;
1370 let mut minor = 0;
1371 unsafe {
1372 try_nvrtc!(sys::nvrtcVersion(&raw mut major, &raw mut minor))?;
1373 }
1374 Ok(Version { major, minor })
1375}
1376
1377pub fn supported_architectures() -> Result<Vec<i32>> {
1378 let mut count = 0;
1379 unsafe {
1380 try_nvrtc!(sys::nvrtcGetNumSupportedArchs(&raw mut count))?;
1381 }
1382
1383 let mut architectures = vec![0; count as usize];
1384 if architectures.is_empty() {
1385 return Ok(Vec::new());
1386 }
1387
1388 unsafe {
1389 try_nvrtc!(sys::nvrtcGetSupportedArchs(architectures.as_mut_ptr()))?;
1390 }
1391 Ok(architectures)
1392}
1393
1394pub fn pch_heap_size() -> Result<usize> {
1400 let mut size = 0;
1401 unsafe {
1402 try_nvrtc!(sys::nvrtcGetPCHHeapSize(&raw mut size))?;
1403 }
1404 Ok(size as usize)
1405}
1406
1407pub fn set_pch_heap_size(size: usize) -> Result<()> {
1416 unsafe { try_nvrtc!(sys::nvrtcSetPCHHeapSize(size as _)) }
1417}
1418
/// Flow callback that never requests cancellation. It ignores both arguments and always
/// returns 0, which means "keep compiling".
unsafe extern "C" fn noop_compile_callback(
    _payload: *mut std::ffi::c_void,
    _reserved: *mut std::ffi::c_void,
) -> i32 {
    0
}
1427
/// Flow callback that returns 1 (cancel) while the `AtomicBool` pointed to by `payload` is set,
/// and 0 otherwise. A null payload never cancels.
unsafe extern "C" fn cancel_if_requested_callback(
    payload: *mut std::ffi::c_void,
    _reserved: *mut std::ffi::c_void,
) -> i32 {
    // SAFETY: a non-null payload is expected to point to an `AtomicBool` that outlives the
    // compilation it was registered for (the caller's responsibility).
    match unsafe { payload.cast::<AtomicBool>().as_ref() } {
        Some(cancel) if cancel.load(Ordering::Relaxed) => 1,
        _ => 0,
    }
}
1440
1441pub fn clear_flow_callback(program: &Program) -> Result<()> {
1450 unsafe {
1451 try_nvrtc!(sys::nvrtcSetFlowCallback(
1452 program.handle()?,
1453 Some(noop_compile_callback),
1454 ptr::null_mut(),
1455 ))
1456 }
1457}
1458
/// Spells a boolean as the string `"true"` or `"false"`.
fn bool_flag(value: bool) -> &'static str {
    match value {
        true => "true",
        false => "false",
    }
}
1462
/// Formats each value with `Display` and joins the results with commas (no spaces).
fn join_display(values: &[impl Display]) -> String {
    let rendered: Vec<String> = values.iter().map(|value| format!("{value}")).collect();
    rendered.join(",")
}
1470
/// Joins integers into a comma-separated list with no spaces, e.g. `[1, 2]` becomes `"1,2"`.
fn join_numbers(values: &[i32]) -> String {
    let mut joined = String::new();
    for (index, value) in values.iter().enumerate() {
        if index > 0 {
            joined.push(',');
        }
        joined.push_str(&value.to_string());
    }
    joined
}
1478
/// Converts an NVRTC text buffer to a `String`. Trailing NUL terminators are dropped, interior
/// NULs are kept, and invalid UTF-8 is replaced with U+FFFD.
fn bytes_to_string(mut bytes: Vec<u8>) -> String {
    let end = bytes
        .iter()
        .rposition(|&byte| byte != 0)
        .map_or(0, |last| last + 1);
    bytes.truncate(end);
    match String::from_utf8(bytes) {
        Ok(text) => text,
        Err(invalid) => String::from_utf8_lossy(invalid.as_bytes()).into_owned(),
    }
}
1485
#[cfg(all(test, feature = "testing"))]
mod tests {
    // Tests for the NVRTC wrapper. Tests that only compile and inspect output use NVRTC alone.
    // Tests that load or launch code first call `testing::bootstrap()` and keep the returned
    // guard and context alive for the whole test (presumably to serialize device access —
    // confirm in `crate::testing`).
    use super::*;
    use crate::{
        device::Device, error::Result, memory::DeviceMemory, module::LaunchConfig, testing,
    };

    // Maps the current device's compute capability to the matching real `sm_XX` target, so the
    // generated CUBIN can be loaded on this device. Panics for capabilities not listed here.
    fn current_device_sm_architecture() -> Result<GpuArchitecture> {
        let properties = Device::current()?.properties()?;
        Ok(match (properties.major, properties.minor) {
            (7, 5) => GpuArchitecture::Sm75,
            (8, 0) => GpuArchitecture::Sm80,
            (8, 6) => GpuArchitecture::Sm86,
            (8, 7) => GpuArchitecture::Sm87,
            (8, 9) => GpuArchitecture::Sm89,
            (9, 0) => GpuArchitecture::Sm90,
            (10, 0) => GpuArchitecture::Sm100,
            (10, 1) => GpuArchitecture::Sm101,
            (10, 3) => GpuArchitecture::Sm103,
            (12, 0) => GpuArchitecture::Sm120,
            (12, 1) => GpuArchitecture::Sm121,
            (major, minor) => panic!("unsupported device architecture sm_{major}{minor}"),
        })
    }

    // NVRTC reports a non-zero major version.
    #[test]
    fn version_is_available() {
        let version = version().unwrap();
        assert_ne!(version.major, 0);
    }

    // NVRTC lists at least one supported architecture, in non-decreasing order.
    #[test]
    fn supported_architectures_are_sorted() {
        let architectures = supported_architectures().unwrap();
        assert!(!architectures.is_empty());
        assert!(
            architectures
                .windows(2)
                .all(|window| window[0] <= window[1])
        );
    }

    // The options builder emits one NVRTC flag per setter, in call order.
    #[test]
    fn compile_options_build_expected_arguments() {
        let arguments = CompileOptions::new()
            .gpu_architecture(GpuArchitecture::Compute80)
            .device_debug(true)
            .generate_line_info(true)
            .define_macro(MacroDefinition::WithValue {
                name: "FOO",
                value: "42",
            })
            .include_path("include")
            .cpp_dialect(CppDialect::Cpp20)
            .warning_as_error(WarningAsErrorKind::Reorder)
            .diag_suppress(177)
            .raw_option("--custom-flag")
            .as_arguments();

        assert_eq!(
            arguments,
            vec![
                "--gpu-architecture=compute_80",
                "--device-debug",
                "--generate-line-info",
                "--define-macro=FOO=42",
                "--include-path=include",
                "--std=c++20",
                "--warning-as-error=reorder",
                "--diag-suppress=177",
                "--custom-flag",
            ]
        );
    }

    // Compiling for a virtual architecture yields PTX containing the `extern "C"` entry point.
    #[test]
    fn compiles_to_ptx() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void saxpy(float a, const float* x, const float* y, float* out, size_t n) {
                size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
                if (tid < n) {
                    out[tid] = a * x[tid] + y[tid];
                }
            }
            "#,
        )
        .with_name("saxpy.cu");
        let options = CompileOptions::new()
            .gpu_architecture(GpuArchitecture::Compute80)
            .generate_line_info(true);
        program.compile_with_options(&options).unwrap();

        let ptx = program.ptx_string().unwrap();
        assert!(ptx.contains(".visible .entry saxpy"));
    }

    // A cancel flag that is never set must not interrupt compilation.
    #[test]
    fn compile_with_cancel_flag_succeeds_when_not_cancelled() {
        let (_lock, _ctx) = testing::bootstrap().unwrap();
        let cancel = AtomicBool::new(false);
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop.cu");
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);

        program
            .compile_with_options_and_cancel_flag(&options, &cancel)
            .unwrap();
        assert!(program.ptx_string().unwrap().contains("noop"));
    }

    // Installing the no-op flow callback succeeds on a program that has not been compiled.
    #[test]
    fn clear_flow_callback_is_allowed_before_compilation() {
        let program =
            Program::from_source("extern \"C\" __global__ void noop() {}").with_name("noop.cu");
        clear_flow_callback(&program).unwrap();
    }

    // CUBIN built for this device's real architecture loads as a module exposing the kernel.
    #[test]
    fn cubin_artifact_loads_as_module() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void saxpy(float a, const float* x, const float* y, float* out, size_t n) {
                size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
                if (tid < n) {
                    out[tid] = a * x[tid] + y[tid];
                }
            }
            "#,
        )
        .with_name("saxpy_module.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let module = ctx.load_nvrtc_module(&program, OutputKind::Cubin).unwrap();
        assert!(module.function("saxpy").is_ok());
    }

    // Same as above, but loading through JIT options with caller-provided log buffers.
    #[test]
    fn cubin_artifact_loads_as_module_with_jit_options() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_module_jit.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let mut info_log = [0u8; 1024];
        let mut error_log = [0u8; 1024];
        let jit_options = crate::jit::JitOptions::default()
            .with_generate_line_info(true)
            .with_info_log(&mut info_log)
            .with_error_log(&mut error_log);

        let module = ctx
            .load_nvrtc_module_with_options(&program, OutputKind::Cubin, jit_options)
            .unwrap();
        assert!(module.function("noop").is_ok());
    }

    // End-to-end: compile, load, launch over `input.len()` elements, and compare against the
    // same computation on the host. The inputs are chosen so every product is exact in f32.
    // That keeps the exact `assert_eq!` valid even if the device fuses the multiply-add.
    #[test]
    fn compiles_loads_launches_and_reads_back_results() {
        let (_lock, ctx) = testing::bootstrap().unwrap();

        let input = vec![1.0f32, 2.0, 3.5, -4.0, 8.25];
        let mut output = vec![0.0f32; input.len()];
        let scalar = 2.5f32;
        let input_device = DeviceMemory::from_slice(&input).unwrap();
        let output_device = DeviceMemory::<f32>::zeroes(output.len()).unwrap();
        let length = input.len();

        let program = Program::from_source(
            r#"
            extern "C" __global__ void scale_add(const float* input, float* output, float alpha, size_t len) {
                size_t i = blockIdx.x * blockDim.x + threadIdx.x;
                if (i < len) {
                    output[i] = input[i] * alpha + 1.0f;
                }
            }
            "#,
        )
        .with_name("scale_add.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let compile_options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&compile_options).unwrap();

        let module = ctx.load_nvrtc_module(&program, OutputKind::Cubin).unwrap();
        let function = module.function("scale_add").unwrap();

        let config = LaunchConfig::for_1d_grid(input.len(), 128);
        let input_ptr = input_device.as_ptr();
        let mut output_ptr = output_device.as_mut_ptr();

        // Argument order must match the kernel signature: (input, output, alpha, len).
        function
            .launch(&config, (&input_ptr, &mut output_ptr, &scalar, &length))
            .unwrap();
        output_device.copy_to_host(&mut output).unwrap();

        let expected = input
            .iter()
            .map(|value| value * scalar + 1.0)
            .collect::<Vec<_>>();
        assert_eq!(output, expected);
    }

    // The CUBIN can also be loaded as a library containing at least one kernel.
    #[test]
    fn cubin_artifact_loads_as_library() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_library.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let library = ctx.load_nvrtc_library(&program, OutputKind::Cubin).unwrap();
        assert!(library.kernel_count().unwrap() >= 1);
    }

    // With link-time optimization enabled, both the LTO-IR and PTX artifacts are non-empty.
    #[test]
    fn lto_ir_artifact_is_available_when_requested() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_lto.cu");
        let options = CompileOptions::new().dlink_time_optimization(true);
        program.compile_with_options(&options).unwrap();

        let artifact = program.artifact(OutputKind::LtoIr).unwrap();
        assert!(!artifact.image().as_bytes().is_empty());
        let ptx = program.artifact(OutputKind::Ptx).unwrap();
        assert!(!ptx.image().as_bytes().is_empty());
    }

    // With OptiX IR output enabled, both the OptiX IR and PTX artifacts are non-empty.
    #[test]
    fn optix_ir_artifact_is_available_when_requested() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_optix.cu");
        let options = CompileOptions::new().optix_ir(true);
        program.compile_with_options(&options).unwrap();

        let artifact = program.artifact(OutputKind::OptixIr).unwrap();
        assert!(!artifact.image().as_bytes().is_empty());
        let ptx = program.artifact(OutputKind::Ptx).unwrap();
        assert!(!ptx.image().as_bytes().is_empty());
    }

    // Compiling only for the virtual `compute_80` target leaves no usable CUBIN artifact.
    #[test]
    fn cubin_artifact_requires_real_architecture() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_cubin.cu");
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);
        program.compile_with_options(&options).unwrap();

        assert!(program.artifact(OutputKind::Cubin).is_err());
    }
}
1765}