// singe_cuda/nvrtc.rs

1use std::{
2    cell::UnsafeCell,
3    ffi::{CStr, CString},
4    fmt::{self, Display},
5    ptr, result,
6    sync::atomic::{AtomicBool, Ordering},
7};
8
9use singe_core::impl_enum_display;
10use singe_cuda_sys::nvrtc as sys;
11
12use crate::{
13    architecture::GpuArchitecture,
14    error::{Error, Result},
15    module::ModuleImage,
16    try_nvrtc,
17};
18
/// A `major.minor` version pair (presumably the NVRTC library version; it is
/// populated outside this view — TODO confirm).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Version {
    /// Major version component.
    pub major: i32,
    /// Minor version component.
    pub minor: i32,
}
24
/// NVRTC result code returned by compiler operations.
///
/// Each named variant corresponds to one `nvrtcResult` constant (see
/// [`Status::raw`]); any other raw code is carried verbatim by
/// [`Status::Unknown`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Status {
    /// `NVRTC_SUCCESS`.
    Success,
    /// `NVRTC_ERROR_OUT_OF_MEMORY`.
    OutOfMemory,
    /// `NVRTC_ERROR_PROGRAM_CREATION_FAILURE`.
    ProgramCreationFailure,
    /// `NVRTC_ERROR_INVALID_INPUT`.
    InvalidInput,
    /// `NVRTC_ERROR_INVALID_PROGRAM`.
    InvalidProgram,
    /// `NVRTC_ERROR_INVALID_OPTION`.
    InvalidOption,
    /// `NVRTC_ERROR_COMPILATION`.
    Compilation,
    /// `NVRTC_ERROR_BUILTIN_OPERATION_FAILURE`.
    BuiltinOperationFailure,
    /// `NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION`.
    NoNameExpressionsAfterCompilation,
    /// `NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION`.
    NoLoweredNamesBeforeCompilation,
    /// `NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID`.
    NameExpressionNotValid,
    /// `NVRTC_ERROR_INTERNAL_ERROR`.
    InternalError,
    /// `NVRTC_ERROR_TIME_FILE_WRITE_FAILED`.
    TimeFileWriteFailed,
    /// `NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED`.
    NoPchCreateAttempted,
    /// `NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED`.
    PchCreateHeapExhausted,
    /// `NVRTC_ERROR_PCH_CREATE`.
    PchCreate,
    /// `NVRTC_ERROR_CANCELLED`.
    Cancelled,
    /// `NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED`.
    TimeTraceFileWriteFailed,
    /// A raw result code with no named variant, preserved as-is.
    Unknown(u32),
}
49
50impl Status {
51    pub fn description(self) -> String {
52        match sys::nvrtcResult::try_from(self.raw()) {
53            Ok(status) => unsafe {
54                let ptr = sys::nvrtcGetErrorString(status);
55                if ptr.is_null() {
56                    String::from("unknown nvrtc error")
57                } else {
58                    CStr::from_ptr(ptr).to_string_lossy().into_owned()
59                }
60            },
61            Err(_) => String::from("unknown nvrtc error"),
62        }
63    }
64
65    pub const fn raw(self) -> u32 {
66        match self {
67            Self::Success => sys::nvrtcResult::NVRTC_SUCCESS as _,
68            Self::OutOfMemory => sys::nvrtcResult::NVRTC_ERROR_OUT_OF_MEMORY as _,
69            Self::ProgramCreationFailure => {
70                sys::nvrtcResult::NVRTC_ERROR_PROGRAM_CREATION_FAILURE as _
71            }
72            Self::InvalidInput => sys::nvrtcResult::NVRTC_ERROR_INVALID_INPUT as _,
73            Self::InvalidProgram => sys::nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM as _,
74            Self::InvalidOption => sys::nvrtcResult::NVRTC_ERROR_INVALID_OPTION as _,
75            Self::Compilation => sys::nvrtcResult::NVRTC_ERROR_COMPILATION as _,
76            Self::BuiltinOperationFailure => {
77                sys::nvrtcResult::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE as _
78            }
79            Self::NoNameExpressionsAfterCompilation => {
80                sys::nvrtcResult::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION as _
81            }
82            Self::NoLoweredNamesBeforeCompilation => {
83                sys::nvrtcResult::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION as _
84            }
85            Self::NameExpressionNotValid => {
86                sys::nvrtcResult::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID as _
87            }
88            Self::InternalError => sys::nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR as _,
89            Self::TimeFileWriteFailed => sys::nvrtcResult::NVRTC_ERROR_TIME_FILE_WRITE_FAILED as _,
90            Self::NoPchCreateAttempted => {
91                sys::nvrtcResult::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED as _
92            }
93            Self::PchCreateHeapExhausted => {
94                sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED as _
95            }
96            Self::PchCreate => sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE as _,
97            Self::Cancelled => sys::nvrtcResult::NVRTC_ERROR_CANCELLED as _,
98            Self::TimeTraceFileWriteFailed => {
99                sys::nvrtcResult::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED as _
100            }
101            Self::Unknown(code) => code,
102        }
103    }
104}
105
106impl TryFrom<u32> for Status {
107    type Error = u32;
108
109    fn try_from(code: u32) -> result::Result<Self, u32> {
110        match code {
111            code if code == sys::nvrtcResult::NVRTC_SUCCESS as u32 => Ok(Self::Success),
112            code if code == sys::nvrtcResult::NVRTC_ERROR_OUT_OF_MEMORY as u32 => {
113                Ok(Self::OutOfMemory)
114            }
115            code if code == sys::nvrtcResult::NVRTC_ERROR_PROGRAM_CREATION_FAILURE as u32 => {
116                Ok(Self::ProgramCreationFailure)
117            }
118            code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_INPUT as u32 => {
119                Ok(Self::InvalidInput)
120            }
121            code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM as u32 => {
122                Ok(Self::InvalidProgram)
123            }
124            code if code == sys::nvrtcResult::NVRTC_ERROR_INVALID_OPTION as u32 => {
125                Ok(Self::InvalidOption)
126            }
127            code if code == sys::nvrtcResult::NVRTC_ERROR_COMPILATION as u32 => {
128                Ok(Self::Compilation)
129            }
130            code if code == sys::nvrtcResult::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE as u32 => {
131                Ok(Self::BuiltinOperationFailure)
132            }
133            code if code
134                == sys::nvrtcResult::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION as u32 =>
135            {
136                Ok(Self::NoNameExpressionsAfterCompilation)
137            }
138            code if code
139                == sys::nvrtcResult::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION as u32 =>
140            {
141                Ok(Self::NoLoweredNamesBeforeCompilation)
142            }
143            code if code == sys::nvrtcResult::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID as u32 => {
144                Ok(Self::NameExpressionNotValid)
145            }
146            code if code == sys::nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR as u32 => {
147                Ok(Self::InternalError)
148            }
149            code if code == sys::nvrtcResult::NVRTC_ERROR_TIME_FILE_WRITE_FAILED as u32 => {
150                Ok(Self::TimeFileWriteFailed)
151            }
152            code if code == sys::nvrtcResult::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED as u32 => {
153                Ok(Self::NoPchCreateAttempted)
154            }
155            code if code == sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED as u32 => {
156                Ok(Self::PchCreateHeapExhausted)
157            }
158            code if code == sys::nvrtcResult::NVRTC_ERROR_PCH_CREATE as u32 => Ok(Self::PchCreate),
159            code if code == sys::nvrtcResult::NVRTC_ERROR_CANCELLED as u32 => Ok(Self::Cancelled),
160            code if code == sys::nvrtcResult::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED as u32 => {
161                Ok(Self::TimeTraceFileWriteFailed)
162            }
163            code => Err(code),
164        }
165    }
166}
167
168impl From<sys::nvrtcResult> for Status {
169    fn from(status: sys::nvrtcResult) -> Self {
170        Self::try_from(status as u32).unwrap_or_else(Self::Unknown)
171    }
172}
173
174impl TryFrom<Status> for sys::nvrtcResult {
175    type Error = Status;
176
177    fn try_from(status: Status) -> result::Result<Self, Status> {
178        match status {
179            Status::Success => Ok(Self::NVRTC_SUCCESS),
180            Status::OutOfMemory => Ok(Self::NVRTC_ERROR_OUT_OF_MEMORY),
181            Status::ProgramCreationFailure => Ok(Self::NVRTC_ERROR_PROGRAM_CREATION_FAILURE),
182            Status::InvalidInput => Ok(Self::NVRTC_ERROR_INVALID_INPUT),
183            Status::InvalidProgram => Ok(Self::NVRTC_ERROR_INVALID_PROGRAM),
184            Status::InvalidOption => Ok(Self::NVRTC_ERROR_INVALID_OPTION),
185            Status::Compilation => Ok(Self::NVRTC_ERROR_COMPILATION),
186            Status::BuiltinOperationFailure => Ok(Self::NVRTC_ERROR_BUILTIN_OPERATION_FAILURE),
187            Status::NoNameExpressionsAfterCompilation => {
188                Ok(Self::NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION)
189            }
190            Status::NoLoweredNamesBeforeCompilation => {
191                Ok(Self::NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION)
192            }
193            Status::NameExpressionNotValid => Ok(Self::NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID),
194            Status::InternalError => Ok(Self::NVRTC_ERROR_INTERNAL_ERROR),
195            Status::TimeFileWriteFailed => Ok(Self::NVRTC_ERROR_TIME_FILE_WRITE_FAILED),
196            Status::NoPchCreateAttempted => Ok(Self::NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED),
197            Status::PchCreateHeapExhausted => Ok(Self::NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED),
198            Status::PchCreate => Ok(Self::NVRTC_ERROR_PCH_CREATE),
199            Status::Cancelled => Ok(Self::NVRTC_ERROR_CANCELLED),
200            Status::TimeTraceFileWriteFailed => Ok(Self::NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED),
201            Status::Unknown(_) => Err(status),
202        }
203    }
204}
205
206impl Display for Status {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Self::Success => f.write_str("NVRTC_SUCCESS"),
210            Self::OutOfMemory => f.write_str("NVRTC_ERROR_OUT_OF_MEMORY"),
211            Self::ProgramCreationFailure => f.write_str("NVRTC_ERROR_PROGRAM_CREATION_FAILURE"),
212            Self::InvalidInput => f.write_str("NVRTC_ERROR_INVALID_INPUT"),
213            Self::InvalidProgram => f.write_str("NVRTC_ERROR_INVALID_PROGRAM"),
214            Self::InvalidOption => f.write_str("NVRTC_ERROR_INVALID_OPTION"),
215            Self::Compilation => f.write_str("NVRTC_ERROR_COMPILATION"),
216            Self::BuiltinOperationFailure => f.write_str("NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"),
217            Self::NoNameExpressionsAfterCompilation => {
218                f.write_str("NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION")
219            }
220            Self::NoLoweredNamesBeforeCompilation => {
221                f.write_str("NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION")
222            }
223            Self::NameExpressionNotValid => f.write_str("NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID"),
224            Self::InternalError => f.write_str("NVRTC_ERROR_INTERNAL_ERROR"),
225            Self::TimeFileWriteFailed => f.write_str("NVRTC_ERROR_TIME_FILE_WRITE_FAILED"),
226            Self::NoPchCreateAttempted => f.write_str("NVRTC_ERROR_NO_PCH_CREATE_ATTEMPTED"),
227            Self::PchCreateHeapExhausted => f.write_str("NVRTC_ERROR_PCH_CREATE_HEAP_EXHAUSTED"),
228            Self::PchCreate => f.write_str("NVRTC_ERROR_PCH_CREATE"),
229            Self::Cancelled => f.write_str("NVRTC_ERROR_CANCELLED"),
230            Self::TimeTraceFileWriteFailed => {
231                f.write_str("NVRTC_ERROR_TIME_TRACE_FILE_WRITE_FAILED")
232            }
233            Self::Unknown(code) => write!(f, "UNKNOWN_NVRTC_STATUS({code})"),
234        }
235    }
236}
237
/// A borrowed in-memory header: `source` is its text and `include_name` the
/// name used to `#include` it (presumably copied into an owned form when
/// attached to a `Program` — TODO confirm).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Header<'a> {
    /// Header source text.
    pub source: &'a str,
    /// Name under which the header can be `#include`d.
    pub include_name: &'a str,
}
243
/// Owned header text plus include name, as stored in [`Program`]'s header list.
#[derive(Clone, Debug)]
struct OwnedHeader {
    /// Header source text.
    source: String,
    /// Name under which the header can be `#include`d.
    include_name: String,
}
249
/// A preprocessor macro definition passed to NVRTC via `--define-macro`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum MacroDefinition<'a> {
    /// Defines `name` with no explicit value (`--define-macro=name`).
    Name(&'a str),
    /// Defines `name` as `value` (`--define-macro=name=value`).
    WithValue { name: &'a str, value: &'a str },
}
256
257impl MacroDefinition<'_> {
258    fn format(self) -> String {
259        match self {
260            Self::Name(name) => name.to_string(),
261            Self::WithValue { name, value } => format!("{name}={value}"),
262        }
263    }
264}
265
266#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
267#[non_exhaustive]
268pub enum CppDialect {
269    Cpp03,
270    Cpp11,
271    Cpp14,
272    Cpp17,
273    Cpp20,
274}
275
276impl_enum_display!(CppDialect, {
277    Self::Cpp03 => "c++03",
278    Self::Cpp11 => "c++11",
279    Self::Cpp14 => "c++14",
280    Self::Cpp17 => "c++17",
281    Self::Cpp20 => "c++20",
282});
283
284#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
285#[non_exhaustive]
286pub enum FastCompileLevel {
287    Zero,
288    Min,
289    Mid,
290    Max,
291}
292
293impl_enum_display!(FastCompileLevel, {
294    Self::Zero => "0",
295    Self::Min => "min",
296    Self::Mid => "mid",
297    Self::Max => "max",
298});
299
300#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
301#[non_exhaustive]
302pub enum WarningAsErrorKind {
303    AllWarnings,
304    Reorder,
305    DeprecatedDeclarations,
306}
307
308impl_enum_display!(WarningAsErrorKind, {
309    Self::AllWarnings => "all-warnings",
310    Self::Reorder => "reorder",
311    Self::DeprecatedDeclarations => "deprecated-declarations",
312});
313
314#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
315#[non_exhaustive]
316pub enum OptimizationInfoKind {
317    Inline,
318}
319
320impl_enum_display!(OptimizationInfoKind, {
321    Self::Inline => "inline",
322});
323
324#[derive(Debug, Clone, Default)]
325pub struct CompileOptions<'a> {
326    pub gpu_architecture: Option<GpuArchitecture>,
327    pub relocatable_device_code: Option<bool>,
328    pub extensible_whole_program: bool,
329    pub device_debug: bool,
330    pub generate_line_info: bool,
331    pub device_optimization: Option<bool>,
332    pub fast_compile: Option<FastCompileLevel>,
333    pub ptxas_options: Vec<&'a str>,
334    pub max_register_count: Option<i32>,
335    pub flush_to_zero: Option<bool>,
336    pub precise_square_root: Option<bool>,
337    pub precise_division: Option<bool>,
338    pub fmad: Option<bool>,
339    pub use_fast_math: bool,
340    pub extra_device_vectorization: bool,
341    pub modify_stack_limit: Option<bool>,
342    pub dlink_time_optimization: bool,
343    pub generate_optimized_lto: bool,
344    pub optix_ir: bool,
345    pub jump_table_density: Option<u8>,
346    pub no_cache: bool,
347    pub random_seed: Option<&'a str>,
348    pub define_macros: Vec<MacroDefinition<'a>>,
349    pub undefine_macros: Vec<&'a str>,
350    pub include_paths: Vec<&'a str>,
351    pub pre_include_headers: Vec<&'a str>,
352    pub no_source_include: bool,
353    pub cpp_dialect: Option<CppDialect>,
354    pub builtin_move_forward: Option<bool>,
355    pub builtin_initializer_list: Option<bool>,
356    pub pch: bool,
357    pub create_pch: Option<&'a str>,
358    pub use_pch: Option<&'a str>,
359    pub pch_dir: Option<&'a str>,
360    pub pch_verbose: Option<bool>,
361    pub pch_messages: Option<bool>,
362    pub instantiate_templates_in_pch: Option<bool>,
363    pub disable_warnings: bool,
364    pub warning_as_error: Vec<WarningAsErrorKind>,
365    pub restrict_pointers: bool,
366    pub device_as_default_execution_space: bool,
367    pub device_int128: bool,
368    pub device_float128: bool,
369    pub optimization_info: Vec<OptimizationInfoKind>,
370    pub display_error_number: Option<bool>,
371    pub diag_error: Vec<i32>,
372    pub diag_suppress: Vec<i32>,
373    pub diag_warn: Vec<i32>,
374    pub brief_diagnostics: Option<bool>,
375    pub time: Option<&'a str>,
376    pub split_compile: Option<i32>,
377    pub device_syntax_only: bool,
378    pub minimal: bool,
379    pub device_stack_protector: Option<bool>,
380    pub device_time_trace: Option<&'a str>,
381    pub raw_options: Vec<&'a str>,
382}
383
384impl<'a> CompileOptions<'a> {
385    pub const fn new() -> Self {
386        Self {
387            gpu_architecture: None,
388            relocatable_device_code: None,
389            extensible_whole_program: false,
390            device_debug: false,
391            generate_line_info: false,
392            device_optimization: None,
393            fast_compile: None,
394            ptxas_options: Vec::new(),
395            max_register_count: None,
396            flush_to_zero: None,
397            precise_square_root: None,
398            precise_division: None,
399            fmad: None,
400            use_fast_math: false,
401            extra_device_vectorization: false,
402            modify_stack_limit: None,
403            dlink_time_optimization: false,
404            generate_optimized_lto: false,
405            optix_ir: false,
406            jump_table_density: None,
407            no_cache: false,
408            random_seed: None,
409            define_macros: Vec::new(),
410            undefine_macros: Vec::new(),
411            include_paths: Vec::new(),
412            pre_include_headers: Vec::new(),
413            no_source_include: false,
414            cpp_dialect: None,
415            builtin_move_forward: None,
416            builtin_initializer_list: None,
417            pch: false,
418            create_pch: None,
419            use_pch: None,
420            pch_dir: None,
421            pch_verbose: None,
422            pch_messages: None,
423            instantiate_templates_in_pch: None,
424            disable_warnings: false,
425            warning_as_error: Vec::new(),
426            restrict_pointers: false,
427            device_as_default_execution_space: false,
428            device_int128: false,
429            device_float128: false,
430            optimization_info: Vec::new(),
431            display_error_number: None,
432            diag_error: Vec::new(),
433            diag_suppress: Vec::new(),
434            diag_warn: Vec::new(),
435            brief_diagnostics: None,
436            time: None,
437            split_compile: None,
438            device_syntax_only: false,
439            minimal: false,
440            device_stack_protector: None,
441            device_time_trace: None,
442            raw_options: Vec::new(),
443        }
444    }
445
446    pub fn gpu_architecture(mut self, value: GpuArchitecture) -> Self {
447        self.gpu_architecture = Some(value);
448        self
449    }
450
451    pub fn relocatable_device_code(mut self, value: bool) -> Self {
452        self.relocatable_device_code = Some(value);
453        self
454    }
455
456    pub fn extensible_whole_program(mut self, value: bool) -> Self {
457        self.extensible_whole_program = value;
458        self
459    }
460
461    pub fn device_debug(mut self, value: bool) -> Self {
462        self.device_debug = value;
463        self
464    }
465
466    pub fn generate_line_info(mut self, value: bool) -> Self {
467        self.generate_line_info = value;
468        self
469    }
470
471    pub fn device_optimization(mut self, value: bool) -> Self {
472        self.device_optimization = Some(value);
473        self
474    }
475
476    pub fn fast_compile(mut self, value: FastCompileLevel) -> Self {
477        self.fast_compile = Some(value);
478        self
479    }
480
481    pub fn ptxas_option(mut self, value: &'a str) -> Self {
482        self.ptxas_options.push(value);
483        self
484    }
485
486    pub fn max_register_count(mut self, value: i32) -> Self {
487        self.max_register_count = Some(value);
488        self
489    }
490
491    pub fn flush_to_zero(mut self, value: bool) -> Self {
492        self.flush_to_zero = Some(value);
493        self
494    }
495
496    pub fn precise_square_root(mut self, value: bool) -> Self {
497        self.precise_square_root = Some(value);
498        self
499    }
500
501    pub fn precise_division(mut self, value: bool) -> Self {
502        self.precise_division = Some(value);
503        self
504    }
505
506    pub fn fmad(mut self, value: bool) -> Self {
507        self.fmad = Some(value);
508        self
509    }
510
511    pub fn use_fast_math(mut self, value: bool) -> Self {
512        self.use_fast_math = value;
513        self
514    }
515
516    pub fn extra_device_vectorization(mut self, value: bool) -> Self {
517        self.extra_device_vectorization = value;
518        self
519    }
520
521    pub fn modify_stack_limit(mut self, value: bool) -> Self {
522        self.modify_stack_limit = Some(value);
523        self
524    }
525
526    pub fn dlink_time_optimization(mut self, value: bool) -> Self {
527        self.dlink_time_optimization = value;
528        self
529    }
530
531    pub fn generate_optimized_lto(mut self, value: bool) -> Self {
532        self.generate_optimized_lto = value;
533        self
534    }
535
536    pub fn optix_ir(mut self, value: bool) -> Self {
537        self.optix_ir = value;
538        self
539    }
540
541    pub fn jump_table_density(mut self, value: u8) -> Self {
542        self.jump_table_density = Some(value.min(101));
543        self
544    }
545
546    pub fn no_cache(mut self, value: bool) -> Self {
547        self.no_cache = value;
548        self
549    }
550
551    pub fn random_seed(mut self, value: &'a str) -> Self {
552        self.random_seed = Some(value);
553        self
554    }
555
556    pub fn define_macro(mut self, value: MacroDefinition<'a>) -> Self {
557        self.define_macros.push(value);
558        self
559    }
560
561    pub fn undefine_macro(mut self, value: &'a str) -> Self {
562        self.undefine_macros.push(value);
563        self
564    }
565
566    pub fn include_path(mut self, value: &'a str) -> Self {
567        self.include_paths.push(value);
568        self
569    }
570
571    pub fn pre_include_header(mut self, value: &'a str) -> Self {
572        self.pre_include_headers.push(value);
573        self
574    }
575
576    pub fn no_source_include(mut self, value: bool) -> Self {
577        self.no_source_include = value;
578        self
579    }
580
581    pub fn cpp_dialect(mut self, value: CppDialect) -> Self {
582        self.cpp_dialect = Some(value);
583        self
584    }
585
586    pub fn builtin_move_forward(mut self, value: bool) -> Self {
587        self.builtin_move_forward = Some(value);
588        self
589    }
590
591    pub fn builtin_initializer_list(mut self, value: bool) -> Self {
592        self.builtin_initializer_list = Some(value);
593        self
594    }
595
596    pub fn pch(mut self, value: bool) -> Self {
597        self.pch = value;
598        self
599    }
600
601    pub fn create_pch(mut self, value: &'a str) -> Self {
602        self.create_pch = Some(value);
603        self
604    }
605
606    pub fn use_pch(mut self, value: &'a str) -> Self {
607        self.use_pch = Some(value);
608        self
609    }
610
611    pub fn pch_dir(mut self, value: &'a str) -> Self {
612        self.pch_dir = Some(value);
613        self
614    }
615
616    pub fn pch_verbose(mut self, value: bool) -> Self {
617        self.pch_verbose = Some(value);
618        self
619    }
620
621    pub fn pch_messages(mut self, value: bool) -> Self {
622        self.pch_messages = Some(value);
623        self
624    }
625
626    pub fn instantiate_templates_in_pch(mut self, value: bool) -> Self {
627        self.instantiate_templates_in_pch = Some(value);
628        self
629    }
630
631    pub fn disable_warnings(mut self, value: bool) -> Self {
632        self.disable_warnings = value;
633        self
634    }
635
636    pub fn warning_as_error(mut self, value: WarningAsErrorKind) -> Self {
637        self.warning_as_error.push(value);
638        self
639    }
640
641    pub fn restrict_pointers(mut self, value: bool) -> Self {
642        self.restrict_pointers = value;
643        self
644    }
645
646    pub fn device_as_default_execution_space(mut self, value: bool) -> Self {
647        self.device_as_default_execution_space = value;
648        self
649    }
650
651    pub fn device_int128(mut self, value: bool) -> Self {
652        self.device_int128 = value;
653        self
654    }
655
656    pub fn device_float128(mut self, value: bool) -> Self {
657        self.device_float128 = value;
658        self
659    }
660
661    pub fn optimization_info(mut self, value: OptimizationInfoKind) -> Self {
662        self.optimization_info.push(value);
663        self
664    }
665
666    pub fn display_error_number(mut self, value: bool) -> Self {
667        self.display_error_number = Some(value);
668        self
669    }
670
671    pub fn diag_error(mut self, value: i32) -> Self {
672        self.diag_error.push(value);
673        self
674    }
675
676    pub fn diag_suppress(mut self, value: i32) -> Self {
677        self.diag_suppress.push(value);
678        self
679    }
680
681    pub fn diag_warn(mut self, value: i32) -> Self {
682        self.diag_warn.push(value);
683        self
684    }
685
686    pub fn brief_diagnostics(mut self, value: bool) -> Self {
687        self.brief_diagnostics = Some(value);
688        self
689    }
690
691    pub fn time(mut self, value: &'a str) -> Self {
692        self.time = Some(value);
693        self
694    }
695
696    pub fn split_compile(mut self, value: i32) -> Self {
697        self.split_compile = Some(value);
698        self
699    }
700
701    pub fn device_syntax_only(mut self, value: bool) -> Self {
702        self.device_syntax_only = value;
703        self
704    }
705
706    pub fn minimal(mut self, value: bool) -> Self {
707        self.minimal = value;
708        self
709    }
710
711    pub fn device_stack_protector(mut self, value: bool) -> Self {
712        self.device_stack_protector = Some(value);
713        self
714    }
715
716    pub fn device_time_trace(mut self, value: &'a str) -> Self {
717        self.device_time_trace = Some(value);
718        self
719    }
720
721    pub fn raw_option(mut self, value: &'a str) -> Self {
722        self.raw_options.push(value);
723        self
724    }
725
726    pub fn as_arguments(&self) -> Vec<String> {
727        let mut arguments = Vec::new();
728
729        if let Some(value) = self.gpu_architecture {
730            arguments.push(format!("--gpu-architecture={value}"));
731        }
732        if let Some(value) = self.relocatable_device_code {
733            arguments.push(format!("--relocatable-device-code={}", bool_flag(value)));
734        }
735        if self.extensible_whole_program {
736            arguments.push(String::from("--extensible-whole-program"));
737        }
738        if self.device_debug {
739            arguments.push(String::from("--device-debug"));
740        }
741        if self.generate_line_info {
742            arguments.push(String::from("--generate-line-info"));
743        }
744        if let Some(value) = self.device_optimization
745            && value
746        {
747            arguments.push(String::from("--dopt=on"));
748        }
749        if let Some(value) = self.fast_compile {
750            arguments.push(format!("--Ofast-compile={value}"));
751        }
752        arguments.extend(
753            self.ptxas_options
754                .iter()
755                .map(|value| format!("--ptxas-options={value}")),
756        );
757        if let Some(value) = self.max_register_count {
758            arguments.push(format!("--maxrregcount={value}"));
759        }
760        if let Some(value) = self.flush_to_zero {
761            arguments.push(format!("--ftz={}", bool_flag(value)));
762        }
763        if let Some(value) = self.precise_square_root {
764            arguments.push(format!("--prec-sqrt={}", bool_flag(value)));
765        }
766        if let Some(value) = self.precise_division {
767            arguments.push(format!("--prec-div={}", bool_flag(value)));
768        }
769        if let Some(value) = self.fmad {
770            arguments.push(format!("--fmad={}", bool_flag(value)));
771        }
772        if self.use_fast_math {
773            arguments.push(String::from("--use_fast_math"));
774        }
775        if self.extra_device_vectorization {
776            arguments.push(String::from("--extra-device-vectorization"));
777        }
778        if let Some(value) = self.modify_stack_limit {
779            arguments.push(format!("--modify-stack-limit={}", bool_flag(value)));
780        }
781        if self.dlink_time_optimization {
782            arguments.push(String::from("--dlink-time-opt"));
783        }
784        if self.generate_optimized_lto {
785            arguments.push(String::from("--gen-opt-lto"));
786        }
787        if self.optix_ir {
788            arguments.push(String::from("--optix-ir"));
789        }
790        if let Some(value) = self.jump_table_density {
791            arguments.push(format!("--jump-table-density={value}"));
792        }
793        if self.no_cache {
794            arguments.push(String::from("--no-cache"));
795        }
796        if let Some(value) = self.random_seed {
797            arguments.push(format!("--frandom-seed={value}"));
798        }
799        arguments.extend(
800            self.define_macros
801                .iter()
802                .copied()
803                .map(|value| format!("--define-macro={}", value.format())),
804        );
805        arguments.extend(
806            self.undefine_macros
807                .iter()
808                .map(|value| format!("--undefine-macro={value}")),
809        );
810        arguments.extend(
811            self.include_paths
812                .iter()
813                .map(|value| format!("--include-path={value}")),
814        );
815        arguments.extend(
816            self.pre_include_headers
817                .iter()
818                .map(|value| format!("--pre-include={value}")),
819        );
820        if self.no_source_include {
821            arguments.push(String::from("--no-source-include"));
822        }
823        if let Some(value) = self.cpp_dialect {
824            arguments.push(format!("--std={value}"));
825        }
826        if let Some(value) = self.builtin_move_forward {
827            arguments.push(format!("--builtin-move-forward={}", bool_flag(value)));
828        }
829        if let Some(value) = self.builtin_initializer_list {
830            arguments.push(format!("--builtin-initializer-list={}", bool_flag(value)));
831        }
832        if self.pch {
833            arguments.push(String::from("--pch"));
834        }
835        if let Some(value) = self.create_pch {
836            arguments.push(format!("--create-pch={value}"));
837        }
838        if let Some(value) = self.use_pch {
839            arguments.push(format!("--use-pch={value}"));
840        }
841        if let Some(value) = self.pch_dir {
842            arguments.push(format!("--pch-dir={value}"));
843        }
844        if let Some(value) = self.pch_verbose {
845            arguments.push(format!("--pch-verbose={}", bool_flag(value)));
846        }
847        if let Some(value) = self.pch_messages {
848            arguments.push(format!("--pch-messages={}", bool_flag(value)));
849        }
850        if let Some(value) = self.instantiate_templates_in_pch {
851            arguments.push(format!(
852                "--instantiate-templates-in-pch={}",
853                bool_flag(value)
854            ));
855        }
856        if self.disable_warnings {
857            arguments.push(String::from("--disable-warnings"));
858        }
859        if !self.warning_as_error.is_empty() {
860            arguments.push(format!(
861                "--warning-as-error={}",
862                join_display(&self.warning_as_error)
863            ));
864        }
865        if self.restrict_pointers {
866            arguments.push(String::from("--restrict"));
867        }
868        if self.device_as_default_execution_space {
869            arguments.push(String::from("--device-as-default-execution-space"));
870        }
871        if self.device_int128 {
872            arguments.push(String::from("--device-int128"));
873        }
874        if self.device_float128 {
875            arguments.push(String::from("--device-float128"));
876        }
877        arguments.extend(
878            self.optimization_info
879                .iter()
880                .map(|value| format!("--optimization-info={value}")),
881        );
882        if let Some(value) = self.display_error_number {
883            arguments.push(if value {
884                String::from("--display-error-number")
885            } else {
886                String::from("--no-display-error-number")
887            });
888        }
889        if !self.diag_error.is_empty() {
890            arguments.push(format!("--diag-error={}", join_numbers(&self.diag_error)));
891        }
892        if !self.diag_suppress.is_empty() {
893            arguments.push(format!(
894                "--diag-suppress={}",
895                join_numbers(&self.diag_suppress)
896            ));
897        }
898        if !self.diag_warn.is_empty() {
899            arguments.push(format!("--diag-warn={}", join_numbers(&self.diag_warn)));
900        }
901        if let Some(value) = self.brief_diagnostics {
902            arguments.push(format!("--brief-diagnostics={}", bool_flag(value)));
903        }
904        if let Some(value) = self.time {
905            arguments.push(format!("--time={value}"));
906        }
907        if let Some(value) = self.split_compile {
908            arguments.push(format!("--split-compile={value}"));
909        }
910        if self.device_syntax_only {
911            arguments.push(String::from("--fdevice-syntax-only"));
912        }
913        if self.minimal {
914            arguments.push(String::from("--minimal"));
915        }
916        if let Some(value) = self.device_stack_protector {
917            arguments.push(format!("--device-stack-protector={}", bool_flag(value)));
918        }
919        if let Some(value) = self.device_time_trace {
920            arguments.push(format!("--fdevice-time-trace={value}"));
921        }
922
923        arguments.extend(self.raw_options.iter().map(|value| (*value).to_string()));
924        arguments
925    }
926}
927
/// An NVRTC program: CUDA source text plus an optional name and in-memory
/// headers, backed by a lazily created `nvrtcProgram` handle.
///
/// The handle starts out null and is created the first time a method needs it.
/// If a handle was created (or adopted via [`Program::from_raw`]), `Drop`
/// destroys it with `nvrtcDestroyProgram`.
#[derive(Debug)]
pub struct Program {
    /// Source passed to `nvrtcCreateProgram`; empty for programs adopted via
    /// [`Program::from_raw`].
    source: String,
    /// Optional program name passed to `nvrtcCreateProgram`.
    name: Option<String>,
    /// In-memory headers passed to `nvrtcCreateProgram` with their include names.
    headers: Vec<OwnedHeader>,
    /// Raw handle, null until materialized. The `UnsafeCell` lets `&self`
    /// methods create it lazily (and makes `Program` `!Sync`).
    handle: UnsafeCell<sys::nvrtcProgram>,
}
935
936impl Program {
937    /// Creates a lazy NVRTC program from CUDA source text.
938    ///
939    /// The raw NVRTC handle is created when the program is compiled or
940    /// otherwise materialized.
941    pub fn from_source(source: &str) -> Self {
942        Self {
943            source: source.to_string(),
944            name: None,
945            headers: Vec::new(),
946            handle: UnsafeCell::new(ptr::null_mut()),
947        }
948    }
949
950    pub fn with_name(mut self, name: &str) -> Self {
951        self.name = Some(name.to_string());
952        self
953    }
954
955    pub fn with_header(mut self, header: Header<'_>) -> Self {
956        self.headers.push(OwnedHeader {
957            source: header.source.to_string(),
958            include_name: header.include_name.to_string(),
959        });
960        self
961    }
962
963    pub fn with_headers(mut self, headers: &[Header<'_>]) -> Self {
964        self.headers
965            .extend(headers.iter().map(|header| OwnedHeader {
966                source: header.source.to_string(),
967                include_name: header.include_name.to_string(),
968            }));
969        self
970    }
971
972    /// Takes ownership of a raw NVRTC program handle.
973    ///
974    /// # Safety
975    ///
976    /// `handle` must be a valid `nvrtcProgram` that is not owned by any other
977    /// wrapper. The returned wrapper destroys it with `nvrtcDestroyProgram`.
978    pub unsafe fn from_raw(handle: sys::nvrtcProgram) -> Result<Self> {
979        if handle.is_null() {
980            return Err(Error::NullHandle);
981        }
982
983        Ok(Self {
984            source: String::new(),
985            name: None,
986            headers: Vec::new(),
987            handle: UnsafeCell::new(handle),
988        })
989    }
990
991    pub fn compile(&self, options: &[&str]) -> Result<()> {
992        self.compile_raw(options)
993    }
994
995    pub fn compile_with_options(&self, options: &CompileOptions<'_>) -> Result<()> {
996        let arguments = options.as_arguments();
997        let argument_refs = arguments.iter().map(String::as_str).collect::<Vec<_>>();
998        self.compile_raw(&argument_refs)
999    }
1000
1001    /// Registers a flow callback that can cancel compilation.
1002    ///
1003    /// The callback function must satisfy the following constraints:
1004    ///
1005    /// (1) It must return 1 to cancel compilation or 0 to continue.
1006    /// Other return values are reserved for future use.
1007    ///
1008    /// (2) It must return consistent values.
1009    /// Once it returns 1 at one point, it must return 1 in all following invocations during the current nvrtcCompileProgram call in progress.
1010    ///
1011    /// (3) It must be thread-safe.
1012    ///
1013    /// (4) It must not invoke any NVRTC, libNVVM, or PTX APIs.
1014    pub fn compile_with_options_and_cancel_flag(
1015        &self,
1016        options: &CompileOptions<'_>,
1017        cancel: &AtomicBool,
1018    ) -> Result<()> {
1019        unsafe {
1020            try_nvrtc!(sys::nvrtcSetFlowCallback(
1021                self.handle()?,
1022                Some(cancel_if_requested_callback),
1023                ptr::from_ref(cancel).cast_mut().cast(),
1024            ))?;
1025        }
1026
1027        let compile_result = self.compile_with_options(options);
1028        let clear_result = clear_flow_callback(self);
1029
1030        match (compile_result, clear_result) {
1031            (Err(error), _) => Err(error),
1032            (Ok(()), Err(error)) => Err(error),
1033            (Ok(()), Ok(())) => Ok(()),
1034        }
1035    }
1036
1037    /// Registers a name expression for a `__global__`, `__device__`, or `__constant__` symbol.
1038    ///
1039    /// The identical name expression string must be provided to [`Program::lowered_name`] after compilation.
1040    ///
1041    /// # Errors
1042    ///
1043    /// Returns an error if `name_expression` contains an interior NUL byte, if
1044    /// the program handle is invalid, or if NVRTC rejects the expression.
1045    pub fn add_name_expression(&self, name_expression: &str) -> Result<()> {
1046        let name_expression = CString::new(name_expression)?;
1047        unsafe {
1048            try_nvrtc!(sys::nvrtcAddNameExpression(
1049                self.handle()?,
1050                name_expression.as_ptr(),
1051            ))
1052        }
1053    }
1054
1055    /// Returns the lowered (mangled) name for a registered name expression.
1056    ///
1057    /// The identical name expression must have been previously provided to [`Program::add_name_expression`].
1058    ///
1059    /// # Errors
1060    ///
1061    /// Returns an error if `name_expression` contains an interior NUL byte, if
1062    /// the program handle is invalid, or if NVRTC cannot return the lowered name.
1063    pub fn lowered_name(&self, name_expression: &str) -> Result<String> {
1064        let name_expression = CString::new(name_expression)?;
1065        let mut lowered_name = ptr::null();
1066        unsafe {
1067            try_nvrtc!(sys::nvrtcGetLoweredName(
1068                self.handle()?,
1069                name_expression.as_ptr(),
1070                &raw mut lowered_name,
1071            ))?;
1072            Ok(CStr::from_ptr(lowered_name).to_string_lossy().into_owned())
1073        }
1074    }
1075
1076    pub fn ptx(&self) -> Result<Vec<u8>> {
1077        self.bytes(sys::nvrtcGetPTXSize, sys::nvrtcGetPTX)
1078    }
1079
1080    pub fn ptx_image(&self) -> Result<ModuleImage<'static>> {
1081        Ok(ModuleImage::from_vec(self.ptx()?))
1082    }
1083
1084    pub fn ptx_string(&self) -> Result<String> {
1085        Ok(bytes_to_string(self.ptx()?))
1086    }
1087
1088    pub fn cubin(&self) -> Result<Vec<u8>> {
1089        self.bytes(sys::nvrtcGetCUBINSize, sys::nvrtcGetCUBIN)
1090    }
1091
1092    pub fn cubin_image(&self) -> Result<ModuleImage<'static>> {
1093        Ok(ModuleImage::from_vec(self.cubin()?))
1094    }
1095
1096    pub fn lto_ir(&self) -> Result<Vec<u8>> {
1097        self.bytes(sys::nvrtcGetLTOIRSize, sys::nvrtcGetLTOIR)
1098    }
1099
1100    pub fn lto_ir_image(&self) -> Result<ModuleImage<'static>> {
1101        Ok(ModuleImage::from_vec(self.lto_ir()?))
1102    }
1103
1104    pub fn optix_ir(&self) -> Result<Vec<u8>> {
1105        self.bytes(sys::nvrtcGetOptiXIRSize, sys::nvrtcGetOptiXIR)
1106    }
1107
1108    pub fn optix_ir_image(&self) -> Result<ModuleImage<'static>> {
1109        Ok(ModuleImage::from_vec(self.optix_ir()?))
1110    }
1111
1112    pub fn log(&self) -> Result<String> {
1113        Ok(bytes_to_string(self.bytes(
1114            sys::nvrtcGetProgramLogSize,
1115            sys::nvrtcGetProgramLog,
1116        )?))
1117    }
1118
1119    /// Returns the PCH creation status.
1120    ///
1121    /// [`Status::Success`] indicates that the PCH was successfully created.
1122    /// [`Status::NoPchCreateAttempted`] indicates that no PCH creation was attempted, either because PCH was not requested during the preceding compile call, or automatic PCH processing was requested, and the compiler chose not to create a PCH file.
1123    /// [`Status::PchCreateHeapExhausted`] indicates that a PCH file could potentially have been created, but the compiler ran out space in the PCH heap.
1124    /// In this scenario, use [`Program::pch_heap_size_required`] to query the required heap size, reallocate the heap with [`set_pch_heap_size`], and retry PCH creation with [`sys::nvrtcCompileProgram`](singe_cuda_sys::nvrtc::nvrtcCompileProgram) on a new NVRTC program instance.
1125    /// [`Status::PchCreate`] indicates that an error condition prevented the PCH file from being created.
1126    ///
1127    /// # Errors
1128    ///
1129    /// Returns an error if the program handle is invalid.
1130    pub fn pch_create_status(&self) -> Result<Status> {
1131        unsafe { Ok(sys::nvrtcGetPCHCreateStatus(self.handle()?).into()) }
1132    }
1133
1134    /// Returns the PCH heap size required to compile the given program.
1135    ///
1136    /// # Errors
1137    ///
1138    /// Returns an error if the program handle is invalid or if the returned size
1139    /// is not valid because [`Program::pch_create_status`] did not return
1140    /// [`Status::Success`] or [`Status::PchCreateHeapExhausted`].
1141    pub fn pch_heap_size_required(&self) -> Result<usize> {
1142        let mut size = 0;
1143        unsafe {
1144            try_nvrtc!(sys::nvrtcGetPCHHeapSizeRequired(
1145                self.handle()?,
1146                &raw mut size
1147            ))?;
1148        }
1149        Ok(size as usize)
1150    }
1151
1152    /// Returns the raw NVRTC program handle, creating it first if needed.
1153    ///
1154    /// If the program has not been materialized yet, this creates the raw
1155    /// handle first.
1156    ///
1157    /// # Errors
1158    ///
1159    /// Returns an error if NVRTC cannot create the program handle.
1160    pub fn materialize_raw(&self) -> Result<sys::nvrtcProgram> {
1161        self.handle()
1162    }
1163
1164    /// Returns the raw NVRTC program handle only if it has already been
1165    /// materialized.
1166    pub const fn as_raw_if_materialized(&self) -> Option<sys::nvrtcProgram> {
1167        let handle = unsafe { *self.handle.get() };
1168        if handle.is_null() { None } else { Some(handle) }
1169    }
1170
1171    /// Transfers ownership of the raw NVRTC program handle to the caller.
1172    ///
1173    /// If the program has not been materialized yet, this creates the raw
1174    /// handle first. The caller becomes responsible for destroying the returned
1175    /// handle with `nvrtcDestroyProgram`.
1176    pub fn into_raw(self) -> Result<sys::nvrtcProgram> {
1177        let handle = self.handle()?;
1178        std::mem::forget(self);
1179        Ok(handle)
1180    }
1181
1182    fn compile_raw(&self, options: &[&str]) -> Result<()> {
1183        let options = options
1184            .iter()
1185            .map(|option| CString::new(*option))
1186            .collect::<result::Result<Vec<_>, _>>()?;
1187        let option_ptrs = options
1188            .iter()
1189            .map(|value| value.as_ptr())
1190            .collect::<Vec<_>>();
1191
1192        unsafe {
1193            try_nvrtc!(sys::nvrtcCompileProgram(
1194                self.handle()?,
1195                option_ptrs.len() as _,
1196                if option_ptrs.is_empty() {
1197                    ptr::null()
1198                } else {
1199                    option_ptrs.as_ptr()
1200                },
1201            ))
1202        }
1203    }
1204
1205    fn bytes(
1206        &self,
1207        get_size: unsafe extern "C" fn(sys::nvrtcProgram, *mut sys::size_t) -> sys::nvrtcResult,
1208        get_data: unsafe extern "C" fn(sys::nvrtcProgram, *mut i8) -> sys::nvrtcResult,
1209    ) -> Result<Vec<u8>> {
1210        let mut size = 0;
1211        unsafe {
1212            try_nvrtc!(get_size(self.handle()?, &raw mut size))?;
1213        }
1214
1215        let mut bytes = vec![0u8; size as usize];
1216        if bytes.is_empty() {
1217            return Ok(bytes);
1218        }
1219
1220        unsafe {
1221            try_nvrtc!(get_data(self.handle()?, bytes.as_mut_ptr().cast()))?;
1222        }
1223        Ok(bytes)
1224    }
1225
1226    fn handle(&self) -> Result<sys::nvrtcProgram> {
1227        unsafe {
1228            let handle = self.handle.get();
1229            if (*handle).is_null() {
1230                *handle = self.create_handle()?;
1231            }
1232            Ok(*handle)
1233        }
1234    }
1235
1236    fn create_handle(&self) -> Result<sys::nvrtcProgram> {
1237        let source = CString::new(self.source.as_str())?;
1238        let name = self.name.as_deref().map(CString::new).transpose()?;
1239        let header_sources = self
1240            .headers
1241            .iter()
1242            .map(|header| CString::new(header.source.as_str()))
1243            .collect::<result::Result<Vec<_>, _>>()?;
1244        let include_names = self
1245            .headers
1246            .iter()
1247            .map(|header| CString::new(header.include_name.as_str()))
1248            .collect::<result::Result<Vec<_>, _>>()?;
1249        let header_ptrs = header_sources
1250            .iter()
1251            .map(|value| value.as_ptr())
1252            .collect::<Vec<_>>();
1253        let include_name_ptrs = include_names
1254            .iter()
1255            .map(|value| value.as_ptr())
1256            .collect::<Vec<_>>();
1257
1258        let mut handle = ptr::null_mut();
1259        unsafe {
1260            try_nvrtc!(sys::nvrtcCreateProgram(
1261                &raw mut handle,
1262                source.as_ptr(),
1263                name.as_ref().map_or(ptr::null(), |value| value.as_ptr()),
1264                self.headers.len() as _,
1265                if header_ptrs.is_empty() {
1266                    ptr::null()
1267                } else {
1268                    header_ptrs.as_ptr()
1269                },
1270                if include_name_ptrs.is_empty() {
1271                    ptr::null()
1272                } else {
1273                    include_name_ptrs.as_ptr()
1274                },
1275            ))?;
1276        }
1277
1278        Ok(handle)
1279    }
1280}
1281
1282impl Drop for Program {
1283    fn drop(&mut self) {
1284        unsafe {
1285            let handle = self.handle.get();
1286            if !(*handle).is_null() {
1287                let _ = sys::nvrtcDestroyProgram(handle);
1288            }
1289        }
1290    }
1291}
1292
/// A compiled NVRTC output wrapped as a [`ModuleImage`], tagged with its kind.
///
/// Produced by [`Program::artifact`].
/// NOTE(review): not `Debug`; consider deriving it if `ModuleImage` implements `Debug`.
#[non_exhaustive]
pub enum CompilationArtifact {
    /// PTX, as returned by `nvrtcGetPTX`.
    Ptx(ModuleImage<'static>),
    /// CUBIN, as returned by `nvrtcGetCUBIN`.
    Cubin(ModuleImage<'static>),
    /// LTO IR, as returned by `nvrtcGetLTOIR`.
    LtoIr(ModuleImage<'static>),
    /// OptiX IR, as returned by `nvrtcGetOptiXIR`.
    OptixIr(ModuleImage<'static>),
}
1300
1301impl CompilationArtifact {
1302    pub fn image(&self) -> &ModuleImage<'static> {
1303        match self {
1304            Self::Ptx(image) | Self::Cubin(image) | Self::LtoIr(image) | Self::OptixIr(image) => {
1305                image
1306            }
1307        }
1308    }
1309
1310    pub fn into_image(self) -> ModuleImage<'static> {
1311        match self {
1312            Self::Ptx(image) | Self::Cubin(image) | Self::LtoIr(image) | Self::OptixIr(image) => {
1313                image
1314            }
1315        }
1316    }
1317}
1318
/// Selects which compiled output [`Program::artifact`] retrieves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum OutputKind {
    /// PTX text.
    Ptx,
    /// CUBIN; empty (and rejected by `artifact`) after compiling for a
    /// virtual `compute_*` architecture.
    Cubin,
    /// LTO IR, available when link-time optimization was requested.
    LtoIr,
    /// OptiX IR, available when OptiX IR output was requested.
    OptixIr,
}
1327
1328impl Program {
1329    pub fn artifact(&self, kind: OutputKind) -> Result<CompilationArtifact> {
1330        match kind {
1331            OutputKind::Ptx => {
1332                let image = self.ptx_image()?;
1333                if image.as_bytes().is_empty() {
1334                    return Err(Error::InvalidValue);
1335                }
1336                Ok(CompilationArtifact::Ptx(image))
1337            }
1338            OutputKind::Cubin => {
1339                let image = self.cubin_image()?;
1340                if image.as_bytes().is_empty() {
1341                    return Err(Error::InvalidValue);
1342                }
1343                Ok(CompilationArtifact::Cubin(image))
1344            }
1345            OutputKind::LtoIr => {
1346                let image = self.lto_ir_image()?;
1347                if image.as_bytes().is_empty() {
1348                    return Err(Error::InvalidValue);
1349                }
1350                Ok(CompilationArtifact::LtoIr(image))
1351            }
1352            OutputKind::OptixIr => {
1353                let image = self.optix_ir_image()?;
1354                if image.as_bytes().is_empty() {
1355                    return Err(Error::InvalidValue);
1356                }
1357                Ok(CompilationArtifact::OptixIr(image))
1358            }
1359        }
1360    }
1361}
1362
1363/// Returns the CUDA Runtime Compilation version.
1364///
1365/// # Errors
1366///
1367/// Returns an error if NVRTC cannot report its version.
1368pub fn version() -> Result<Version> {
1369    let mut major = 0;
1370    let mut minor = 0;
1371    unsafe {
1372        try_nvrtc!(sys::nvrtcVersion(&raw mut major, &raw mut minor))?;
1373    }
1374    Ok(Version { major, minor })
1375}
1376
1377pub fn supported_architectures() -> Result<Vec<i32>> {
1378    let mut count = 0;
1379    unsafe {
1380        try_nvrtc!(sys::nvrtcGetNumSupportedArchs(&raw mut count))?;
1381    }
1382
1383    let mut architectures = vec![0; count as usize];
1384    if architectures.is_empty() {
1385        return Ok(Vec::new());
1386    }
1387
1388    unsafe {
1389        try_nvrtc!(sys::nvrtcGetSupportedArchs(architectures.as_mut_ptr()))?;
1390    }
1391    Ok(architectures)
1392}
1393
1394/// Returns the current size of the PCH heap.
1395///
1396/// # Errors
1397///
1398/// Returns an error if NVRTC cannot report the PCH heap size.
1399pub fn pch_heap_size() -> Result<usize> {
1400    let mut size = 0;
1401    unsafe {
1402        try_nvrtc!(sys::nvrtcGetPCHHeapSize(&raw mut size))?;
1403    }
1404    Ok(size as usize)
1405}
1406
1407/// Sets the size of the PCH heap.
1408///
1409/// The requested size may be rounded up to a platform-dependent alignment, such as the page size.
1410/// If the PCH heap has already been allocated, NVRTC frees it and allocates a new heap.
1411///
1412/// # Errors
1413///
1414/// Returns an error if NVRTC rejects the requested PCH heap size.
1415pub fn set_pch_heap_size(size: usize) -> Result<()> {
1416    unsafe { try_nvrtc!(sys::nvrtcSetPCHHeapSize(size as _)) }
1417}
1418
/// Flow callback that never requests cancellation.
///
/// [`clear_flow_callback`] installs it to replace a cancellation callback. It
/// ignores both pointers and always returns 0 ("continue").
unsafe extern "C" fn noop_compile_callback(
    payload: *mut std::ffi::c_void,
    reserved: *mut std::ffi::c_void,
) -> i32 {
    let _ = (payload, reserved);
    0
}
1427
/// Flow callback that asks NVRTC to cancel once the `AtomicBool` behind
/// `payload` reads `true`.
///
/// Returns 1 (cancel) or 0 (continue), and 0 when `payload` is null. It only
/// performs an atomic load and makes no NVRTC calls. NVRTC expects the answer
/// to stay 1 once it has returned 1 during a compile, so the flag should not be
/// reset to `false` mid-compilation.
///
/// # Safety
///
/// A non-null `payload` must point to a live `AtomicBool`.
unsafe extern "C" fn cancel_if_requested_callback(
    payload: *mut std::ffi::c_void,
    reserved: *mut std::ffi::c_void,
) -> i32 {
    let _ = reserved;
    // SAFETY: the caller guarantees a non-null payload points to a live
    // `AtomicBool`; `as_ref` maps null to `None`.
    match unsafe { payload.cast::<AtomicBool>().as_ref() } {
        Some(flag) => i32::from(flag.load(Ordering::Relaxed)),
        None => 0,
    }
}
1440
1441/// Clears the NVRTC flow callback by installing a no-op callback.
1442///
1443/// Use this after a compile attempt that installed a cancellation callback.
1444///
1445/// # Errors
1446///
1447/// Returns an error if the program handle is invalid or if NVRTC rejects the
1448/// callback update.
1449pub fn clear_flow_callback(program: &Program) -> Result<()> {
1450    unsafe {
1451        try_nvrtc!(sys::nvrtcSetFlowCallback(
1452            program.handle()?,
1453            Some(noop_compile_callback),
1454            ptr::null_mut(),
1455        ))
1456    }
1457}
1458
/// Renders a boolean as the literal `true`/`false` used in `--flag=value` options.
fn bool_flag(value: bool) -> &'static str {
    match value {
        true => "true",
        false => "false",
    }
}
1462
/// Joins the `Display` form of each value with commas (no spaces).
fn join_display(values: &[impl Display]) -> String {
    let mut joined = String::new();
    for (position, value) in values.iter().enumerate() {
        if position > 0 {
            joined.push(',');
        }
        joined.push_str(&value.to_string());
    }
    joined
}
1470
/// Joins diagnostic numbers with commas, as NVRTC's `--diag-*` options expect.
fn join_numbers(values: &[i32]) -> String {
    values.iter().fold(String::new(), |mut acc, number| {
        // Only the very first number lands in an empty accumulator.
        if !acc.is_empty() {
            acc.push(',');
        }
        acc.push_str(&number.to_string());
        acc
    })
}
1478
/// Converts an NVRTC text buffer to a `String`.
///
/// Trailing NUL bytes are dropped (interior NULs are kept), and invalid UTF-8
/// is replaced with U+FFFD.
fn bytes_to_string(bytes: Vec<u8>) -> String {
    let end = bytes
        .iter()
        .rposition(|&byte| byte != 0)
        .map_or(0, |last_non_nul| last_non_nul + 1);
    String::from_utf8_lossy(&bytes[..end]).into_owned()
}
1485
/// NVRTC integration tests; only built with the `testing` feature. Tests that
/// load or launch code first call `testing::bootstrap` for a device context.
#[cfg(all(test, feature = "testing"))]
mod tests {
    use super::*;
    use crate::{
        device::Device, error::Result, memory::DeviceMemory, module::LaunchConfig, testing,
    };

    /// Maps the current device's compute capability to its real `sm_*` target.
    fn current_device_sm_architecture() -> Result<GpuArchitecture> {
        let properties = Device::current()?.properties()?;
        Ok(match (properties.major, properties.minor) {
            (7, 5) => GpuArchitecture::Sm75,
            (8, 0) => GpuArchitecture::Sm80,
            (8, 6) => GpuArchitecture::Sm86,
            (8, 7) => GpuArchitecture::Sm87,
            (8, 9) => GpuArchitecture::Sm89,
            (9, 0) => GpuArchitecture::Sm90,
            (10, 0) => GpuArchitecture::Sm100,
            (10, 1) => GpuArchitecture::Sm101,
            (10, 3) => GpuArchitecture::Sm103,
            (12, 0) => GpuArchitecture::Sm120,
            (12, 1) => GpuArchitecture::Sm121,
            (major, minor) => panic!("unsupported device architecture sm_{major}{minor}"),
        })
    }

    #[test]
    fn version_is_available() {
        let version = version().unwrap();
        assert_ne!(version.major, 0);
    }

    #[test]
    fn supported_architectures_are_sorted() {
        let architectures = supported_architectures().unwrap();
        assert!(!architectures.is_empty());
        assert!(
            architectures
                .windows(2)
                .all(|window| window[0] <= window[1])
        );
    }

    #[test]
    fn compile_options_build_expected_arguments() {
        let arguments = CompileOptions::new()
            .gpu_architecture(GpuArchitecture::Compute80)
            .device_debug(true)
            .generate_line_info(true)
            .define_macro(MacroDefinition::WithValue {
                name: "FOO",
                value: "42",
            })
            .include_path("include")
            .cpp_dialect(CppDialect::Cpp20)
            .warning_as_error(WarningAsErrorKind::Reorder)
            .diag_suppress(177)
            .raw_option("--custom-flag")
            .as_arguments();

        assert_eq!(
            arguments,
            vec![
                "--gpu-architecture=compute_80",
                "--device-debug",
                "--generate-line-info",
                "--define-macro=FOO=42",
                "--include-path=include",
                "--std=c++20",
                "--warning-as-error=reorder",
                "--diag-suppress=177",
                "--custom-flag",
            ]
        );
    }

    #[test]
    fn compiles_to_ptx() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void saxpy(float a, const float* x, const float* y, float* out, size_t n) {
                size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
                if (tid < n) {
                    out[tid] = a * x[tid] + y[tid];
                }
            }
            "#,
        )
        .with_name("saxpy.cu");
        let options = CompileOptions::new()
            .gpu_architecture(GpuArchitecture::Compute80)
            .generate_line_info(true);
        program.compile_with_options(&options).unwrap();

        let ptx = program.ptx_string().unwrap();
        assert!(ptx.contains(".visible .entry saxpy"));
    }

    #[test]
    fn compile_with_cancel_flag_succeeds_when_not_cancelled() {
        let (_lock, _ctx) = testing::bootstrap().unwrap();
        let cancel = AtomicBool::new(false);
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop.cu");
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);

        program
            .compile_with_options_and_cancel_flag(&options, &cancel)
            .unwrap();
        assert!(program.ptx_string().unwrap().contains("noop"));
    }

    #[test]
    fn clear_flow_callback_is_allowed_before_compilation() {
        let program =
            Program::from_source("extern \"C\" __global__ void noop() {}").with_name("noop.cu");
        clear_flow_callback(&program).unwrap();
    }

    #[test]
    fn cubin_artifact_loads_as_module() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void saxpy(float a, const float* x, const float* y, float* out, size_t n) {
                size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
                if (tid < n) {
                    out[tid] = a * x[tid] + y[tid];
                }
            }
            "#,
        )
        .with_name("saxpy_module.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let module = ctx.load_nvrtc_module(&program, OutputKind::Cubin).unwrap();
        assert!(module.function("saxpy").is_ok());
    }

    #[test]
    fn cubin_artifact_loads_as_module_with_jit_options() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_module_jit.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let mut info_log = [0u8; 1024];
        let mut error_log = [0u8; 1024];
        let jit_options = crate::jit::JitOptions::default()
            .with_generate_line_info(true)
            .with_info_log(&mut info_log)
            .with_error_log(&mut error_log);

        let module = ctx
            .load_nvrtc_module_with_options(&program, OutputKind::Cubin, jit_options)
            .unwrap();
        assert!(module.function("noop").is_ok());
    }

    #[test]
    fn compiles_loads_launches_and_reads_back_results() {
        let (_lock, ctx) = testing::bootstrap().unwrap();

        let input = vec![1.0f32, 2.0, 3.5, -4.0, 8.25];
        let mut output = vec![0.0f32; input.len()];
        let scalar = 2.5f32;
        let input_device = DeviceMemory::from_slice(&input).unwrap();
        let output_device = DeviceMemory::<f32>::zeroes(output.len()).unwrap();
        let length = input.len();

        let program = Program::from_source(
            r#"
            extern "C" __global__ void scale_add(const float* input, float* output, float alpha, size_t len) {
                size_t i = blockIdx.x * blockDim.x + threadIdx.x;
                if (i < len) {
                    output[i] = input[i] * alpha + 1.0f;
                }
            }
            "#,
        )
        .with_name("scale_add.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let compile_options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&compile_options).unwrap();

        let module = ctx.load_nvrtc_module(&program, OutputKind::Cubin).unwrap();
        let function = module.function("scale_add").unwrap();

        let config = LaunchConfig::for_1d_grid(input.len(), 128);
        let input_ptr = input_device.as_ptr();
        let mut output_ptr = output_device.as_mut_ptr();

        function
            .launch(&config, (&input_ptr, &mut output_ptr, &scalar, &length))
            .unwrap();
        output_device.copy_to_host(&mut output).unwrap();

        let expected = input
            .iter()
            .map(|value| value * scalar + 1.0)
            .collect::<Vec<_>>();
        assert_eq!(output, expected);
    }

    #[test]
    fn cubin_artifact_loads_as_library() {
        let (_lock, ctx) = testing::bootstrap().unwrap();
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_library.cu");
        let architecture = current_device_sm_architecture().unwrap();
        let options = CompileOptions::new().gpu_architecture(architecture);
        program.compile_with_options(&options).unwrap();

        let library = ctx.load_nvrtc_library(&program, OutputKind::Cubin).unwrap();
        assert!(library.kernel_count().unwrap() >= 1);
    }

    #[test]
    fn lto_ir_artifact_is_available_when_requested() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_lto.cu");
        let options = CompileOptions::new().dlink_time_optimization(true);
        program.compile_with_options(&options).unwrap();

        let artifact = program.artifact(OutputKind::LtoIr).unwrap();
        assert!(!artifact.image().as_bytes().is_empty());
        let ptx = program.artifact(OutputKind::Ptx).unwrap();
        assert!(!ptx.image().as_bytes().is_empty());
    }

    #[test]
    fn optix_ir_artifact_is_available_when_requested() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_optix.cu");
        let options = CompileOptions::new().optix_ir(true);
        program.compile_with_options(&options).unwrap();

        let artifact = program.artifact(OutputKind::OptixIr).unwrap();
        assert!(!artifact.image().as_bytes().is_empty());
        let ptx = program.artifact(OutputKind::Ptx).unwrap();
        assert!(!ptx.image().as_bytes().is_empty());
    }

    #[test]
    fn cubin_artifact_requires_real_architecture() {
        let program = Program::from_source(
            r#"
            extern "C" __global__ void noop() {}
            "#,
        )
        .with_name("noop_cubin.cu");
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);
        program.compile_with_options(&options).unwrap();

        assert!(program.artifact(OutputKind::Cubin).is_err());
    }

    /// `into_raw` must hand over a live handle that `from_raw` can adopt again.
    #[test]
    fn into_raw_round_trips_through_from_raw() {
        let program =
            Program::from_source("extern \"C\" __global__ void noop() {}").with_name("noop_raw.cu");
        let handle = program.into_raw().unwrap();
        assert!(!handle.is_null());

        let program = unsafe { Program::from_raw(handle) }.unwrap();
        assert_eq!(program.as_raw_if_materialized(), Some(handle));
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);
        program.compile_with_options(&options).unwrap();
        assert!(program.ptx_string().unwrap().contains("noop"));
    }

    /// A registered name expression resolves to the mangled symbol in the PTX.
    #[test]
    fn lowered_name_resolves_registered_template_instantiation() {
        let program = Program::from_source(
            r#"
            template <typename T>
            __global__ void fill(T* out) { out[0] = T(1); }
            "#,
        )
        .with_name("fill.cu");
        program.add_name_expression("fill<float>").unwrap();
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Compute80);
        program.compile_with_options(&options).unwrap();

        let lowered = program.lowered_name("fill<float>").unwrap();
        assert!(lowered.starts_with("_Z"));
        assert!(program.ptx_string().unwrap().contains(&lowered));
    }
}