Skip to main content

singe_cuda/
nvvm.rs

1use std::{
2    ffi::{CStr, CString},
3    fmt, ptr, result,
4};
5
6use singe_core::impl_enum_display;
7use singe_cuda_sys::nvvm as sys;
8
9use crate::{
10    architecture::GpuArchitecture,
11    error::{Error, Result},
12    module::ModuleImage,
13    try_nvvm,
14};
15
// TODO: move to a core version type
/// Major/minor version pair, as filled in by [`version`] from `nvvmVersion`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Version {
    /// Major component of the version.
    pub major: i32,
    /// Minor component of the version.
    pub minor: i32,
}
22
/// The four version numbers filled in by [`ir_version`] from `nvvmIRVersion`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct IrVersion {
    /// First value written by `nvvmIRVersion`.
    pub major: i32,
    /// Second value written by `nvvmIRVersion`.
    pub minor: i32,
    /// Third value written by `nvvmIRVersion`
    /// (presumably the debug-metadata major version — confirm against the libNVVM docs).
    pub debug_major: i32,
    /// Fourth value written by `nvvmIRVersion`
    /// (presumably the debug-metadata minor version — confirm against the libNVVM docs).
    pub debug_minor: i32,
}
30
/// Result code reported by NVVM compiler operations.
///
/// Each named variant mirrors one `sys::nvvmResult` constant; any other raw
/// code is carried in [`Status::Unknown`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Status {
    /// `NVVM_SUCCESS`
    Success,
    /// `NVVM_ERROR_OUT_OF_MEMORY`
    OutOfMemory,
    /// `NVVM_ERROR_PROGRAM_CREATION_FAILURE`
    ProgramCreationFailure,
    /// `NVVM_ERROR_IR_VERSION_MISMATCH`
    IrVersionMismatch,
    /// `NVVM_ERROR_INVALID_INPUT`
    InvalidInput,
    /// `NVVM_ERROR_INVALID_PROGRAM`
    InvalidProgram,
    /// `NVVM_ERROR_INVALID_IR`
    InvalidIr,
    /// `NVVM_ERROR_INVALID_OPTION`
    InvalidOption,
    /// `NVVM_ERROR_NO_MODULE_IN_PROGRAM`
    NoModuleInProgram,
    /// `NVVM_ERROR_COMPILATION`
    Compilation,
    /// `NVVM_ERROR_CANCELLED`
    Cancelled,
    /// A raw code that matches none of the named variants.
    Unknown(u32),
}
48
49impl Status {
50    pub fn description(self) -> String {
51        match sys::nvvmResult::try_from(self.raw()) {
52            Ok(status) => unsafe {
53                let ptr = sys::nvvmGetErrorString(status);
54                if ptr.is_null() {
55                    String::from("unknown nvvm error")
56                } else {
57                    CStr::from_ptr(ptr).to_string_lossy().into_owned()
58                }
59            },
60            Err(_) => String::from("unknown nvvm error"),
61        }
62    }
63
64    pub const fn raw(self) -> u32 {
65        match self {
66            Self::Success => sys::nvvmResult::NVVM_SUCCESS as _,
67            Self::OutOfMemory => sys::nvvmResult::NVVM_ERROR_OUT_OF_MEMORY as _,
68            Self::ProgramCreationFailure => {
69                sys::nvvmResult::NVVM_ERROR_PROGRAM_CREATION_FAILURE as _
70            }
71            Self::IrVersionMismatch => sys::nvvmResult::NVVM_ERROR_IR_VERSION_MISMATCH as _,
72            Self::InvalidInput => sys::nvvmResult::NVVM_ERROR_INVALID_INPUT as _,
73            Self::InvalidProgram => sys::nvvmResult::NVVM_ERROR_INVALID_PROGRAM as _,
74            Self::InvalidIr => sys::nvvmResult::NVVM_ERROR_INVALID_IR as _,
75            Self::InvalidOption => sys::nvvmResult::NVVM_ERROR_INVALID_OPTION as _,
76            Self::NoModuleInProgram => sys::nvvmResult::NVVM_ERROR_NO_MODULE_IN_PROGRAM as _,
77            Self::Compilation => sys::nvvmResult::NVVM_ERROR_COMPILATION as _,
78            Self::Cancelled => sys::nvvmResult::NVVM_ERROR_CANCELLED as _,
79            Self::Unknown(code) => code,
80        }
81    }
82}
83
84impl TryFrom<u32> for Status {
85    type Error = u32;
86
87    fn try_from(code: u32) -> result::Result<Self, u32> {
88        match code {
89            code if code == sys::nvvmResult::NVVM_SUCCESS as u32 => Ok(Self::Success),
90            code if code == sys::nvvmResult::NVVM_ERROR_OUT_OF_MEMORY as u32 => {
91                Ok(Self::OutOfMemory)
92            }
93            code if code == sys::nvvmResult::NVVM_ERROR_PROGRAM_CREATION_FAILURE as u32 => {
94                Ok(Self::ProgramCreationFailure)
95            }
96            code if code == sys::nvvmResult::NVVM_ERROR_IR_VERSION_MISMATCH as u32 => {
97                Ok(Self::IrVersionMismatch)
98            }
99            code if code == sys::nvvmResult::NVVM_ERROR_INVALID_INPUT as u32 => {
100                Ok(Self::InvalidInput)
101            }
102            code if code == sys::nvvmResult::NVVM_ERROR_INVALID_PROGRAM as u32 => {
103                Ok(Self::InvalidProgram)
104            }
105            code if code == sys::nvvmResult::NVVM_ERROR_INVALID_IR as u32 => Ok(Self::InvalidIr),
106            code if code == sys::nvvmResult::NVVM_ERROR_INVALID_OPTION as u32 => {
107                Ok(Self::InvalidOption)
108            }
109            code if code == sys::nvvmResult::NVVM_ERROR_NO_MODULE_IN_PROGRAM as u32 => {
110                Ok(Self::NoModuleInProgram)
111            }
112            code if code == sys::nvvmResult::NVVM_ERROR_COMPILATION as u32 => Ok(Self::Compilation),
113            code if code == sys::nvvmResult::NVVM_ERROR_CANCELLED as u32 => Ok(Self::Cancelled),
114            code => Err(code),
115        }
116    }
117}
118
119impl From<sys::nvvmResult> for Status {
120    fn from(status: sys::nvvmResult) -> Self {
121        Self::try_from(status as u32).unwrap_or_else(Self::Unknown)
122    }
123}
124
125impl TryFrom<Status> for sys::nvvmResult {
126    type Error = Status;
127
128    fn try_from(status: Status) -> result::Result<Self, Status> {
129        match status {
130            Status::Success => Ok(Self::NVVM_SUCCESS),
131            Status::OutOfMemory => Ok(Self::NVVM_ERROR_OUT_OF_MEMORY),
132            Status::ProgramCreationFailure => Ok(Self::NVVM_ERROR_PROGRAM_CREATION_FAILURE),
133            Status::IrVersionMismatch => Ok(Self::NVVM_ERROR_IR_VERSION_MISMATCH),
134            Status::InvalidInput => Ok(Self::NVVM_ERROR_INVALID_INPUT),
135            Status::InvalidProgram => Ok(Self::NVVM_ERROR_INVALID_PROGRAM),
136            Status::InvalidIr => Ok(Self::NVVM_ERROR_INVALID_IR),
137            Status::InvalidOption => Ok(Self::NVVM_ERROR_INVALID_OPTION),
138            Status::NoModuleInProgram => Ok(Self::NVVM_ERROR_NO_MODULE_IN_PROGRAM),
139            Status::Compilation => Ok(Self::NVVM_ERROR_COMPILATION),
140            Status::Cancelled => Ok(Self::NVVM_ERROR_CANCELLED),
141            Status::Unknown(_) => Err(status),
142        }
143    }
144}
145
146impl fmt::Display for Status {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        match self {
149            Self::Success => f.write_str("NVVM_SUCCESS"),
150            Self::OutOfMemory => f.write_str("NVVM_ERROR_OUT_OF_MEMORY"),
151            Self::ProgramCreationFailure => f.write_str("NVVM_ERROR_PROGRAM_CREATION_FAILURE"),
152            Self::IrVersionMismatch => f.write_str("NVVM_ERROR_IR_VERSION_MISMATCH"),
153            Self::InvalidInput => f.write_str("NVVM_ERROR_INVALID_INPUT"),
154            Self::InvalidProgram => f.write_str("NVVM_ERROR_INVALID_PROGRAM"),
155            Self::InvalidIr => f.write_str("NVVM_ERROR_INVALID_IR"),
156            Self::InvalidOption => f.write_str("NVVM_ERROR_INVALID_OPTION"),
157            Self::NoModuleInProgram => f.write_str("NVVM_ERROR_NO_MODULE_IN_PROGRAM"),
158            Self::Compilation => f.write_str("NVVM_ERROR_COMPILATION"),
159            Self::Cancelled => f.write_str("NVVM_ERROR_CANCELLED"),
160            Self::Unknown(code) => write!(f, "UNKNOWN_NVVM_STATUS({code})"),
161        }
162    }
163}
164
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
166#[non_exhaustive]
167pub enum OptimizationLevel {
168    Zero,
169    Three,
170}
171
172impl_enum_display!(OptimizationLevel, {
173    Self::Zero => "0",
174    Self::Three => "3",
175});
176
177#[derive(Debug, Clone, Default)]
178pub struct CompileOptions<'a> {
179    pub device_debug: bool,
180    pub optimization_level: Option<OptimizationLevel>,
181    pub gpu_architecture: Option<GpuArchitecture>,
182    pub flush_to_zero: Option<bool>,
183    pub precise_square_root: Option<bool>,
184    pub precise_division: Option<bool>,
185    pub fma: Option<bool>,
186    pub jump_table_density: Option<u8>,
187    pub generate_lto: bool,
188    pub raw_options: Vec<&'a str>,
189}
190
191impl<'a> CompileOptions<'a> {
192    pub const fn new() -> Self {
193        Self {
194            device_debug: false,
195            optimization_level: None,
196            gpu_architecture: None,
197            flush_to_zero: None,
198            precise_square_root: None,
199            precise_division: None,
200            fma: None,
201            jump_table_density: None,
202            generate_lto: false,
203            raw_options: Vec::new(),
204        }
205    }
206
207    pub fn device_debug(mut self, value: bool) -> Self {
208        self.device_debug = value;
209        self
210    }
211
212    pub fn optimization_level(mut self, value: OptimizationLevel) -> Self {
213        self.optimization_level = Some(value);
214        self
215    }
216
217    pub fn gpu_architecture(mut self, value: GpuArchitecture) -> Self {
218        self.gpu_architecture = Some(value);
219        self
220    }
221
222    pub fn flush_to_zero(mut self, value: bool) -> Self {
223        self.flush_to_zero = Some(value);
224        self
225    }
226
227    pub fn precise_square_root(mut self, value: bool) -> Self {
228        self.precise_square_root = Some(value);
229        self
230    }
231
232    pub fn precise_division(mut self, value: bool) -> Self {
233        self.precise_division = Some(value);
234        self
235    }
236
237    pub fn fma(mut self, value: bool) -> Self {
238        self.fma = Some(value);
239        self
240    }
241
242    pub fn jump_table_density(mut self, value: u8) -> Self {
243        self.jump_table_density = Some(value.min(101));
244        self
245    }
246
247    pub fn generate_lto(mut self, value: bool) -> Self {
248        self.generate_lto = value;
249        self
250    }
251
252    pub fn raw_option(mut self, value: &'a str) -> Self {
253        self.raw_options.push(value);
254        self
255    }
256
257    pub fn as_arguments(&self) -> Vec<String> {
258        let mut arguments = Vec::new();
259        if self.device_debug {
260            arguments.push(String::from("-g"));
261        }
262        if let Some(value) = self.optimization_level {
263            arguments.push(format!("-opt={value}"));
264        }
265        if let Some(value) = self.gpu_architecture {
266            arguments.push(format!("-arch={value}"));
267        }
268        if let Some(value) = self.flush_to_zero {
269            arguments.push(format!("-ftz={}", flag_bit(value)));
270        }
271        if let Some(value) = self.precise_square_root {
272            arguments.push(format!("-prec-sqrt={}", flag_bit(value)));
273        }
274        if let Some(value) = self.precise_division {
275            arguments.push(format!("-prec-div={}", flag_bit(value)));
276        }
277        if let Some(value) = self.fma {
278            arguments.push(format!("-fma={}", flag_bit(value)));
279        }
280        if let Some(value) = self.jump_table_density {
281            arguments.push(format!("-jump-table-density={value}"));
282        }
283        if self.generate_lto {
284            arguments.push(String::from("-gen-lto"));
285        }
286        arguments.extend(self.raw_options.iter().map(|value| (*value).to_string()));
287        arguments
288    }
289
290    fn validate(&self) -> Result<()> {
291        if let Some(architecture) = self.gpu_architecture
292            && !architecture.is_virtual()
293        {
294            return Err(Error::InvalidValue);
295        }
296        Ok(())
297    }
298}
299
/// A borrowed IR module together with the name it is registered under when
/// added to a [`Program`].
#[derive(Debug)]
pub struct Module<'a> {
    /// Raw module bytes handed to NVVM; an empty slice is rejected when the
    /// module is added.
    pub ir: &'a [u8],
    /// Module name; it is converted to a C string when the module is added.
    pub name: &'a str,
}
305
306#[derive(Debug)]
307pub struct Program {
308    handle: sys::nvvmProgram,
309}
310
311impl Program {
312    pub fn create() -> Result<Self> {
313        let mut handle = ptr::null_mut();
314        unsafe {
315            try_nvvm!(sys::nvvmCreateProgram(&raw mut handle))?;
316        }
317        if handle.is_null() {
318            return Err(Error::NullHandle);
319        }
320        Ok(Self { handle })
321    }
322
323    /// Takes ownership of a raw NVVM program handle.
324    ///
325    /// # Safety
326    ///
327    /// `handle` must be a valid `nvvmProgram` that is not owned by any other
328    /// wrapper. The returned wrapper destroys it with `nvvmDestroyProgram`.
329    pub unsafe fn from_raw(handle: sys::nvvmProgram) -> Result<Self> {
330        if handle.is_null() {
331            return Err(Error::NullHandle);
332        }
333
334        Ok(Self { handle })
335    }
336
337    pub fn add_module(&mut self, module: Module<'_>) -> Result<()> {
338        self.add_module_raw(module, false)
339    }
340
341    pub fn lazy_add_module(&mut self, module: Module<'_>) -> Result<()> {
342        self.add_module_raw(module, true)
343    }
344
345    pub fn compile(&self, options: &[&str]) -> Result<()> {
346        self.compile_raw(sys::nvvmCompileProgram, options)
347    }
348
349    pub fn compile_with_options(&self, options: &CompileOptions<'_>) -> Result<()> {
350        options.validate()?;
351        let arguments = options.as_arguments();
352        let argument_refs = arguments.iter().map(String::as_str).collect::<Vec<_>>();
353        self.compile(&argument_refs)
354    }
355
356    pub fn verify(&self, options: &[&str]) -> Result<()> {
357        self.compile_raw(sys::nvvmVerifyProgram, options)
358    }
359
360    pub fn verify_with_options(&self, options: &CompileOptions<'_>) -> Result<()> {
361        options.validate()?;
362        let arguments = options.as_arguments();
363        let argument_refs = arguments.iter().map(String::as_str).collect::<Vec<_>>();
364        self.verify(&argument_refs)
365    }
366
367    pub fn compiled_result(&self) -> Result<Vec<u8>> {
368        self.bytes(sys::nvvmGetCompiledResultSize, sys::nvvmGetCompiledResult)
369    }
370
371    pub fn compiled_image(&self) -> Result<ModuleImage<'static>> {
372        Ok(ModuleImage::from_vec(self.compiled_result()?))
373    }
374
375    pub fn compiled_string(&self) -> Result<String> {
376        Ok(bytes_to_string(self.compiled_result()?))
377    }
378
379    pub fn log(&self) -> Result<String> {
380        Ok(bytes_to_string(self.bytes(
381            sys::nvvmGetProgramLogSize,
382            sys::nvvmGetProgramLog,
383        )?))
384    }
385
386    pub const fn as_raw(&self) -> sys::nvvmProgram {
387        self.handle
388    }
389
390    /// Transfers ownership of the raw NVVM program handle to the caller.
391    ///
392    /// The caller becomes responsible for destroying it with
393    /// `nvvmDestroyProgram`.
394    pub fn into_raw(self) -> sys::nvvmProgram {
395        let handle = self.handle;
396        std::mem::forget(self);
397        handle
398    }
399
400    fn add_module_raw(&mut self, module: Module<'_>, lazy: bool) -> Result<()> {
401        if module.ir.is_empty() {
402            return Err(Error::InvalidValue);
403        }
404        let name = CString::new(module.name)?;
405        let function = if lazy {
406            sys::nvvmLazyAddModuleToProgram
407        } else {
408            sys::nvvmAddModuleToProgram
409        };
410
411        unsafe {
412            try_nvvm!(function(
413                self.handle,
414                module.ir.as_ptr().cast(),
415                module.ir.len() as _,
416                name.as_ptr(),
417            ))
418        }
419    }
420
421    fn compile_raw(
422        &self,
423        function: unsafe extern "C" fn(sys::nvvmProgram, i32, *mut *const i8) -> sys::nvvmResult,
424        options: &[&str],
425    ) -> Result<()> {
426        let options = options
427            .iter()
428            .map(|option| CString::new(*option))
429            .collect::<result::Result<Vec<_>, _>>()?;
430        let mut option_ptrs = options
431            .iter()
432            .map(|value| value.as_ptr())
433            .collect::<Vec<_>>();
434
435        unsafe {
436            try_nvvm!(function(
437                self.handle,
438                option_ptrs.len() as _,
439                if option_ptrs.is_empty() {
440                    ptr::null_mut()
441                } else {
442                    option_ptrs.as_mut_ptr()
443                },
444            ))
445        }
446    }
447
448    fn bytes(
449        &self,
450        get_size: unsafe extern "C" fn(sys::nvvmProgram, *mut u64) -> sys::nvvmResult,
451        get_data: unsafe extern "C" fn(sys::nvvmProgram, *mut i8) -> sys::nvvmResult,
452    ) -> Result<Vec<u8>> {
453        let mut size = 0;
454        unsafe {
455            try_nvvm!(get_size(self.handle, &raw mut size))?;
456        }
457
458        let mut bytes = vec![0u8; size as usize];
459        if bytes.is_empty() {
460            return Ok(bytes);
461        }
462
463        unsafe {
464            try_nvvm!(get_data(self.handle, bytes.as_mut_ptr().cast()))?;
465        }
466        Ok(bytes)
467    }
468}
469
470impl Drop for Program {
471    fn drop(&mut self) {
472        unsafe {
473            if !self.handle.is_null() {
474                let _ = sys::nvvmDestroyProgram(&raw mut self.handle);
475            }
476        }
477    }
478}
479
480pub fn version() -> Result<Version> {
481    let mut major = 0;
482    let mut minor = 0;
483    unsafe {
484        try_nvvm!(sys::nvvmVersion(&raw mut major, &raw mut minor))?;
485    }
486    Ok(Version { major, minor })
487}
488
489pub fn ir_version() -> Result<IrVersion> {
490    let mut major = 0;
491    let mut minor = 0;
492    let mut debug_major = 0;
493    let mut debug_minor = 0;
494    unsafe {
495        try_nvvm!(sys::nvvmIRVersion(
496            &raw mut major,
497            &raw mut minor,
498            &raw mut debug_major,
499            &raw mut debug_minor,
500        ))?;
501    }
502    Ok(IrVersion {
503        major,
504        minor,
505        debug_major,
506        debug_minor,
507    })
508}
509
510pub fn llvm_version(architecture: GpuArchitecture) -> Result<i32> {
511    if !architecture.is_virtual() {
512        return Err(Error::InvalidValue);
513    }
514    llvm_version_for_architecture(&architecture.to_string())
515}
516
517pub fn llvm_version_for_architecture(architecture: &str) -> Result<i32> {
518    let architecture = CString::new(architecture)?;
519    let mut major = 0;
520    unsafe {
521        try_nvvm!(sys::nvvmLLVMVersion(architecture.as_ptr(), &raw mut major,))?;
522    }
523    Ok(major)
524}
525
/// Maps a boolean to the `1`/`0` digit used by NVVM's `=<0|1>` options.
fn flag_bit(value: bool) -> u8 {
    value.into()
}
529
/// Converts a buffer returned by NVVM into a `String`.
///
/// At most one trailing NUL byte is removed. Invalid UTF-8 sequences are
/// replaced with U+FFFD.
fn bytes_to_string(mut bytes: Vec<u8>) -> String {
    if bytes.last() == Some(&0) {
        bytes.pop();
    }
    // Valid UTF-8 takes over the existing allocation; only invalid input pays
    // for the lossy copy.
    String::from_utf8(bytes)
        .unwrap_or_else(|error| String::from_utf8_lossy(error.as_bytes()).into_owned())
}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// Every typed option renders to its flag, in declaration order, ahead of
    /// the raw options; a jump-table density above 101 is capped.
    #[test]
    fn compile_options_build_expected_arguments() {
        let rendered = CompileOptions::new()
            .device_debug(true)
            .optimization_level(OptimizationLevel::Zero)
            .gpu_architecture(GpuArchitecture::Compute90)
            .flush_to_zero(true)
            .precise_square_root(false)
            .precise_division(true)
            .fma(false)
            .jump_table_density(200)
            .generate_lto(true)
            .raw_option("-custom")
            .as_arguments();

        let expected = [
            "-g",
            "-opt=0",
            "-arch=compute_90",
            "-ftz=1",
            "-prec-sqrt=0",
            "-prec-div=1",
            "-fma=0",
            "-jump-table-density=101",
            "-gen-lto",
            "-custom",
        ];
        assert_eq!(rendered, expected);
    }

    /// The three version queries succeed and report positive major versions.
    #[test]
    fn version_queries_are_available() {
        let library = version().unwrap();
        assert!(library.major > 0);

        let ir = ir_version().unwrap();
        assert!(ir.major > 0);

        let llvm_major = llvm_version(GpuArchitecture::Compute90).unwrap();
        assert!(llvm_major > 0);
    }

    /// A real (non-virtual) architecture is refused by both option validation
    /// and the LLVM version query.
    #[test]
    fn real_architecture_is_rejected_for_nvvm_options() {
        let options = CompileOptions::new().gpu_architecture(GpuArchitecture::Sm90);
        let validation = options.validate();
        assert!(matches!(validation, Err(Error::InvalidValue)));

        let query = llvm_version(GpuArchitecture::Sm90);
        assert!(matches!(query, Err(Error::InvalidValue)));
    }
}