Skip to main content

singe_cuda/
jit.rs

1#![allow(deprecated)]
2
3use std::ptr;
4
5use num_enum::{IntoPrimitive, TryFromPrimitive};
6use singe_cuda_sys::{driver, runtime};
7
8use singe_core::{impl_enum_conversion, impl_enum_display};
9
/// Optional settings for CUDA driver JIT compilation, assembled with the
/// `with_*` builder methods and encoded into FFI arrays by [`JitOptions::build`].
///
/// Each `Some` field produces one option (two for each log buffer: the buffer
/// and its size); `None` fields are omitted. The `'a` lifetime covers the
/// caller-owned wall-time location and log buffers, whose addresses are handed
/// to the driver.
#[derive(Debug, Default)]
pub struct JitOptions<'a> {
    /// Value for `CU_JIT_MAX_REGISTERS`.
    pub max_registers: Option<u32>,
    /// Value for `CU_JIT_THREADS_PER_BLOCK`.
    pub threads_per_block: Option<u32>,
    /// Location passed with `CU_JIT_WALL_TIME`.
    pub wall_time: Option<&'a mut f32>,
    /// Buffer for `CU_JIT_INFO_LOG_BUFFER`; its length (saturated to `u32::MAX`)
    /// is sent as `CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES`.
    pub info_log_buffer: Option<&'a mut [u8]>,
    /// Buffer for `CU_JIT_ERROR_LOG_BUFFER`; its length (saturated to `u32::MAX`)
    /// is sent as `CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES`.
    pub error_log_buffer: Option<&'a mut [u8]>,
    /// Value for `CU_JIT_OPTIMIZATION_LEVEL`; clamped to at most 4 when encoded.
    pub optimization_level: Option<u32>,
    /// When `Some`, emits `CU_JIT_TARGET_FROM_CUCONTEXT`, which carries no value.
    pub target_from_cuda_context: Option<()>,
    /// Value for `CU_JIT_TARGET`.
    pub target: Option<JitTarget>,
    /// Value for `CU_JIT_FALLBACK_STRATEGY`.
    pub fallback_strategy: Option<JitFallback>,
    /// Flag for `CU_JIT_GENERATE_DEBUG_INFO` (encoded as 0/1).
    pub generate_debug_info: Option<bool>,
    /// Flag for `CU_JIT_LOG_VERBOSE` (encoded as 0/1).
    pub log_verbose: Option<bool>,
    /// Flag for `CU_JIT_GENERATE_LINE_INFO` (encoded as 0/1).
    pub generate_line_info: Option<bool>,
    /// Value for `CU_JIT_CACHE_MODE`.
    pub cache_mode: Option<JitCacheMode>,
}
26
27pub struct JitOptionsArtifact {
28    pub names: Vec<driver::CUjit_option>,
29    pub values: Vec<*mut ()>,
30
31    // Storage for values needing stable pointers for the FFI call.
32    storage_target: Option<u32>,
33    storage_fallback: Option<u32>,
34    storage_debug_info: Option<i32>,
35    storage_log_verbose: Option<i32>,
36    storage_line_info: Option<i32>,
37    storage_cache_mode: Option<u32>,
38
39    // For buffers, the driver expects the buffer pointer directly plus mutable size storage.
40    storage_info_log_ptr: Option<*mut u8>,
41    storage_info_log_size: Option<u32>,
42    storage_error_log_ptr: Option<*mut u8>,
43    storage_error_log_size: Option<u32>,
44    storage_max_registers: Option<u32>,
45    storage_threads_per_block: Option<u32>,
46    storage_optimization_level: Option<u32>,
47}
48
49impl<'a> JitOptions<'a> {
50    pub fn with_max_registers(mut self, value: u32) -> Self {
51        self.max_registers = Some(value);
52        self
53    }
54
55    pub fn with_threads_per_block(mut self, value: u32) -> Self {
56        self.threads_per_block = Some(value);
57        self
58    }
59
60    pub fn with_wall_time(mut self, value: &'a mut f32) -> Self {
61        self.wall_time = Some(value);
62        self
63    }
64
65    pub fn with_info_log(mut self, buffer: &'a mut [u8]) -> Self {
66        self.info_log_buffer = Some(buffer);
67        self
68    }
69
70    pub fn with_error_log(mut self, buffer: &'a mut [u8]) -> Self {
71        self.error_log_buffer = Some(buffer);
72        self
73    }
74
75    pub fn with_optimization_level(mut self, level: u32) -> Self {
76        // Clamp to 0-4 as per docs.
77        self.optimization_level = Some(level.min(4));
78        self
79    }
80
81    pub const fn with_target_from_cuda_context(mut self) -> Self {
82        self.target_from_cuda_context = Some(());
83        self
84    }
85
86    pub fn with_target(mut self, target: JitTarget) -> Self {
87        self.target = Some(target);
88        self
89    }
90
91    pub fn with_fallback_strategy(mut self, strategy: JitFallback) -> Self {
92        self.fallback_strategy = Some(strategy);
93        self
94    }
95
96    pub fn with_generate_debug_info(mut self, enable: bool) -> Self {
97        self.generate_debug_info = Some(enable);
98        self
99    }
100
101    pub fn with_log_verbose(mut self, enable: bool) -> Self {
102        self.log_verbose = Some(enable);
103        self
104    }
105
106    pub fn with_generate_line_info(mut self, enable: bool) -> Self {
107        self.generate_line_info = Some(enable);
108        self
109    }
110
111    pub fn with_cache_mode(mut self, mode: JitCacheMode) -> Self {
112        self.cache_mode = Some(mode);
113        self
114    }
115
116    /// Prepares the option and value arrays for the FFI call.
117    /// Populates the internal storage fields to keep pointers stable.
118    ///
119    /// # Safety
120    ///
121    /// The caller must ensure that this [`JitOptions`] instance outlives the FFI call.
122    pub fn build(&mut self) -> JitOptionsArtifact {
123        let mut artifact = JitOptionsArtifact {
124            names: Vec::new(),
125            values: Vec::new(),
126            storage_target: None,
127            storage_fallback: None,
128            storage_debug_info: None,
129            storage_log_verbose: None,
130            storage_line_info: None,
131            storage_cache_mode: None,
132            storage_info_log_ptr: None,
133            storage_info_log_size: None,
134            storage_error_log_ptr: None,
135            storage_error_log_size: None,
136            storage_max_registers: self.max_registers,
137            storage_threads_per_block: self.threads_per_block,
138            storage_optimization_level: self.optimization_level.map(|value| value.clamp(0, 4)),
139        };
140
141        // Populate internal storage.
142        artifact.storage_target = self.target.map(Into::into);
143        artifact.storage_fallback = self.fallback_strategy.map(Into::into);
144        artifact.storage_cache_mode = self.cache_mode.map(Into::into);
145        artifact.storage_debug_info = self.generate_debug_info.map(i32::from);
146        artifact.storage_log_verbose = self.log_verbose.map(i32::from);
147        artifact.storage_line_info = self.generate_line_info.map(i32::from);
148        artifact.storage_info_log_ptr = self
149            .info_log_buffer
150            .as_mut()
151            .map(|slice| slice.as_mut_ptr().cast::<u8>());
152        artifact.storage_info_log_size = self
153            .info_log_buffer
154            .as_ref()
155            .map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
156        artifact.storage_error_log_ptr = self
157            .error_log_buffer
158            .as_mut()
159            .map(|slice| slice.as_mut_ptr().cast::<u8>());
160        artifact.storage_error_log_size = self
161            .error_log_buffer
162            .as_ref()
163            .map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
164
165        // Build the FFI arrays, pointing into `self`.
166        if let Some(ref mut val) = artifact.storage_max_registers {
167            artifact.names.push(JitOption::MaxRegisters.into());
168            artifact.values.push(ptr::from_mut::<u32>(val).cast());
169        }
170        if let Some(ref mut val) = artifact.storage_threads_per_block {
171            artifact.names.push(JitOption::ThreadsPerBlock.into());
172            artifact.values.push(ptr::from_mut::<u32>(val).cast());
173        }
174        // OUT parameters provided by user need direct pointer.
175        if let Some(ref mut val_ref) = self.wall_time {
176            artifact.names.push(JitOption::WallTime.into());
177            artifact.values.push(ptr::from_mut::<f32>(*val_ref).cast());
178        }
179
180        // Buffer handling: Add both pointer and size options if buffer is set.
181        if let Some(info_log_ptr) = artifact.storage_info_log_ptr {
182            artifact.names.push(JitOption::InfoLogBuffer.into());
183            artifact.values.push(info_log_ptr.cast());
184
185            if let Some(ref mut size_val) = artifact.storage_info_log_size {
186                artifact
187                    .names
188                    .push(JitOption::InfoLogBufferSizeBytes.into());
189                artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
190            }
191        }
192        if let Some(error_log_ptr) = artifact.storage_error_log_ptr {
193            artifact.names.push(JitOption::ErrorLogBuffer.into());
194            artifact.values.push(error_log_ptr.cast());
195
196            if let Some(ref mut size_val) = artifact.storage_error_log_size {
197                artifact
198                    .names
199                    .push(JitOption::ErrorLogBufferSizeBytes.into());
200                artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
201            }
202        }
203
204        if let Some(ref mut value) = artifact.storage_optimization_level {
205            artifact.names.push(JitOption::OptimizationLevel.into());
206            artifact.values.push(ptr::from_mut::<u32>(value).cast());
207        }
208        if self.target_from_cuda_context.is_some() {
209            artifact.names.push(JitOption::TargetFromCudaContext.into());
210            artifact.values.push(ptr::null_mut()); // No value needed.
211        }
212        // Point to internal storage for enums/bools converted to primitive types.
213        if let Some(ref mut val) = artifact.storage_target {
214            artifact.names.push(JitOption::Target.into());
215            artifact.values.push(ptr::from_mut::<u32>(val).cast());
216        }
217        if let Some(ref mut val) = artifact.storage_fallback {
218            artifact.names.push(JitOption::FallbackStrategy.into());
219            artifact.values.push(ptr::from_mut::<u32>(val).cast());
220        }
221        if let Some(ref mut val) = artifact.storage_debug_info {
222            artifact.names.push(JitOption::GenerateDebugInfo.into());
223            artifact.values.push(ptr::from_mut::<i32>(val).cast());
224        }
225        if let Some(ref mut val) = artifact.storage_log_verbose {
226            artifact.names.push(JitOption::LogVerbose.into());
227            artifact.values.push(ptr::from_mut::<i32>(val).cast());
228        }
229        if let Some(ref mut val) = artifact.storage_line_info {
230            artifact.names.push(JitOption::GenerateLineInfo.into());
231            artifact.values.push(ptr::from_mut::<i32>(val).cast());
232        }
233        if let Some(ref mut val) = artifact.storage_cache_mode {
234            artifact.names.push(JitOption::CacheMode.into());
235            artifact.values.push(ptr::from_mut::<u32>(val).cast());
236        }
237
238        artifact
239    }
240}
241
/// Rust mirror of the CUDA JIT option identifiers (`CU_JIT_*`), with
/// discriminants taken from the runtime bindings' `cudaJitOption`.
///
/// Variants marked `#[deprecated]` are still listed so every identifier can be
/// converted and displayed. The file-level `#![allow(deprecated)]` lets the code
/// below reference them without warnings.
///
/// `NumOptions` mirrors `CU_JIT_NUM_OPTIONS`. That is presumably the header's
/// count sentinel rather than an option the driver accepts; confirm before
/// passing it to the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum JitOption {
    MaxRegisters = runtime::cudaJitOption::CU_JIT_MAX_REGISTERS as _,
    ThreadsPerBlock = runtime::cudaJitOption::CU_JIT_THREADS_PER_BLOCK as _,
    WallTime = runtime::cudaJitOption::CU_JIT_WALL_TIME as _,
    InfoLogBuffer = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER as _,
    InfoLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES as _,
    ErrorLogBuffer = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER as _,
    ErrorLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES as _,
    OptimizationLevel = runtime::cudaJitOption::CU_JIT_OPTIMIZATION_LEVEL as _,
    TargetFromCudaContext = runtime::cudaJitOption::CU_JIT_TARGET_FROM_CUCONTEXT as _,
    Target = runtime::cudaJitOption::CU_JIT_TARGET as _,
    FallbackStrategy = runtime::cudaJitOption::CU_JIT_FALLBACK_STRATEGY as _,
    GenerateDebugInfo = runtime::cudaJitOption::CU_JIT_GENERATE_DEBUG_INFO as _,
    LogVerbose = runtime::cudaJitOption::CU_JIT_LOG_VERBOSE as _,
    GenerateLineInfo = runtime::cudaJitOption::CU_JIT_GENERATE_LINE_INFO as _,
    CacheMode = runtime::cudaJitOption::CU_JIT_CACHE_MODE as _,
    #[deprecated]
    NewSm3xOpt = runtime::cudaJitOption::CU_JIT_NEW_SM3X_OPT as _,
    FastCompile = runtime::cudaJitOption::CU_JIT_FAST_COMPILE as _,
    GlobalSymbolNames = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_NAMES as _,
    GlobalSymbolAddresses = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_ADDRESSES as _,
    GlobalSymbolCount = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_COUNT as _,
    #[deprecated]
    Lto = runtime::cudaJitOption::CU_JIT_LTO as _,
    #[deprecated]
    Ftz = runtime::cudaJitOption::CU_JIT_FTZ as _,
    #[deprecated]
    PrecDiv = runtime::cudaJitOption::CU_JIT_PREC_DIV as _,
    #[deprecated]
    PrecSqrt = runtime::cudaJitOption::CU_JIT_PREC_SQRT as _,
    #[deprecated]
    Fma = runtime::cudaJitOption::CU_JIT_FMA as _,
    #[deprecated]
    ReferencedKernelNames = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_NAMES as _,
    #[deprecated]
    ReferencedKernelCount = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_COUNT as _,
    #[deprecated]
    ReferencedVariableNames = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_NAMES as _,
    #[deprecated]
    ReferencedVariableCount = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_COUNT as _,
    #[deprecated]
    OptimizeUnusedDeviceVariables =
        runtime::cudaJitOption::CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES as _,
    NumOptions = runtime::cudaJitOption::CU_JIT_NUM_OPTIONS as _,
}

// NOTE(review): `impl_enum_conversion!` comes from `singe_core`. It presumably
// generates conversions between `u32`, `runtime::cudaJitOption` and
// `JitOption`; confirm against the macro definition.
impl_enum_conversion!(u32, runtime::cudaJitOption, JitOption);

// Maps each variant to its C enumerator name (e.g. `CU_JIT_MAX_REGISTERS`).
// `impl_enum_display!` presumably implements `Display` from this table.
impl_enum_display!(JitOption, {
    Self::MaxRegisters => "CU_JIT_MAX_REGISTERS",
    Self::ThreadsPerBlock => "CU_JIT_THREADS_PER_BLOCK",
    Self::WallTime => "CU_JIT_WALL_TIME",
    Self::InfoLogBuffer => "CU_JIT_INFO_LOG_BUFFER",
    Self::InfoLogBufferSizeBytes => "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES",
    Self::ErrorLogBuffer => "CU_JIT_ERROR_LOG_BUFFER",
    Self::ErrorLogBufferSizeBytes => "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES",
    Self::OptimizationLevel => "CU_JIT_OPTIMIZATION_LEVEL",
    Self::TargetFromCudaContext => "CU_JIT_TARGET_FROM_CUCONTEXT",
    Self::Target => "CU_JIT_TARGET",
    Self::FallbackStrategy => "CU_JIT_FALLBACK_STRATEGY",
    Self::GenerateDebugInfo => "CU_JIT_GENERATE_DEBUG_INFO",
    Self::LogVerbose => "CU_JIT_LOG_VERBOSE",
    Self::GenerateLineInfo => "CU_JIT_GENERATE_LINE_INFO",
    Self::CacheMode => "CU_JIT_CACHE_MODE",
    Self::NewSm3xOpt => "CU_JIT_NEW_SM3X_OPT",
    Self::FastCompile => "CU_JIT_FAST_COMPILE",
    Self::GlobalSymbolNames => "CU_JIT_GLOBAL_SYMBOL_NAMES",
    Self::GlobalSymbolAddresses => "CU_JIT_GLOBAL_SYMBOL_ADDRESSES",
    Self::GlobalSymbolCount => "CU_JIT_GLOBAL_SYMBOL_COUNT",
    Self::Lto => "CU_JIT_LTO",
    Self::Ftz => "CU_JIT_FTZ",
    Self::PrecDiv => "CU_JIT_PREC_DIV",
    Self::PrecSqrt => "CU_JIT_PREC_SQRT",
    Self::Fma => "CU_JIT_FMA",
    Self::ReferencedKernelNames => "CU_JIT_REFERENCED_KERNEL_NAMES",
    Self::ReferencedKernelCount => "CU_JIT_REFERENCED_KERNEL_COUNT",
    Self::ReferencedVariableNames => "CU_JIT_REFERENCED_VARIABLE_NAMES",
    Self::ReferencedVariableCount => "CU_JIT_REFERENCED_VARIABLE_COUNT",
    Self::OptimizeUnusedDeviceVariables => "CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES",
    Self::NumOptions => "CU_JIT_NUM_OPTIONS",
});
326
/// Compilation targets for `CU_JIT_TARGET` (selected through
/// `JitOptions::target`), with discriminants taken from the driver bindings'
/// `CUjit_target_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum JitTarget {
    Compute30 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_30 as _,
    Compute32 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_32 as _,
    Compute35 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_35 as _,
    Compute37 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_37 as _,
    Compute50 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_50 as _,
    Compute52 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_52 as _,
    Compute53 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_53 as _,
    Compute60 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_60 as _,
    Compute61 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_61 as _,
    Compute62 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_62 as _,
    Compute70 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_70 as _,
    Compute72 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_72 as _,
    Compute75 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_75 as _,
    Compute80 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_80 as _,
    Compute86 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_86 as _,
    Compute87 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_87 as _,
    Compute89 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_89 as _,
    Compute90 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_90 as _,
}

// NOTE(review): presumably generates conversions between `u32`,
// `driver::CUjit_target` and `JitTarget`; confirm against `singe_core`.
impl_enum_conversion!(u32, driver::CUjit_target, JitTarget);

// Maps each variant to its C enumerator name (e.g. `CU_TARGET_COMPUTE_80`).
// `impl_enum_display!` presumably implements `Display` from this table.
impl_enum_display!(JitTarget, {
    Self::Compute30 => "CU_TARGET_COMPUTE_30",
    Self::Compute32 => "CU_TARGET_COMPUTE_32",
    Self::Compute35 => "CU_TARGET_COMPUTE_35",
    Self::Compute37 => "CU_TARGET_COMPUTE_37",
    Self::Compute50 => "CU_TARGET_COMPUTE_50",
    Self::Compute52 => "CU_TARGET_COMPUTE_52",
    Self::Compute53 => "CU_TARGET_COMPUTE_53",
    Self::Compute60 => "CU_TARGET_COMPUTE_60",
    Self::Compute61 => "CU_TARGET_COMPUTE_61",
    Self::Compute62 => "CU_TARGET_COMPUTE_62",
    Self::Compute70 => "CU_TARGET_COMPUTE_70",
    Self::Compute72 => "CU_TARGET_COMPUTE_72",
    Self::Compute75 => "CU_TARGET_COMPUTE_75",
    Self::Compute80 => "CU_TARGET_COMPUTE_80",
    Self::Compute86 => "CU_TARGET_COMPUTE_86",
    Self::Compute87 => "CU_TARGET_COMPUTE_87",
    Self::Compute89 => "CU_TARGET_COMPUTE_89",
    Self::Compute90 => "CU_TARGET_COMPUTE_90",
});
373
/// Values for `CU_JIT_FALLBACK_STRATEGY` (selected through
/// `JitOptions::fallback_strategy`), mirroring the driver's `CUjit_fallback_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum JitFallback {
    PreferPtx = driver::CUjit_fallback_enum::CU_PREFER_PTX as _,
    PreferBinary = driver::CUjit_fallback_enum::CU_PREFER_BINARY as _,
}

// NOTE(review): presumably generates conversions between `u32`,
// `driver::CUjit_fallback` and `JitFallback`; confirm against `singe_core`.
impl_enum_conversion!(u32, driver::CUjit_fallback, JitFallback);

// Maps each variant to its C enumerator name; presumably a `Display` impl.
impl_enum_display!(JitFallback, {
    Self::PreferPtx => "CU_PREFER_PTX",
    Self::PreferBinary => "CU_PREFER_BINARY",
});
388
/// Values for `CU_JIT_CACHE_MODE` (selected through `JitOptions::cache_mode`),
/// mirroring the driver's `CUjit_cacheMode_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum JitCacheMode {
    OptionNone = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_NONE as _,
    OptionCg = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CG as _,
    OptionCa = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CA as _,
}

// NOTE(review): presumably generates conversions between `u32`,
// `driver::CUjit_cacheMode` and `JitCacheMode`; confirm against `singe_core`.
impl_enum_conversion!(u32, driver::CUjit_cacheMode, JitCacheMode);

// Maps each variant to its C enumerator name; presumably a `Display` impl.
impl_enum_display!(JitCacheMode, {
    Self::OptionNone => "CU_JIT_CACHE_OPTION_NONE",
    Self::OptionCg => "CU_JIT_CACHE_OPTION_CG",
    Self::OptionCa => "CU_JIT_CACHE_OPTION_CA",
});
405
/// Mirror of the driver's `CUjitInputType_enum` input kinds.
///
/// Not referenced by `JitOptions` in this file; presumably consumed by linker
/// code elsewhere in the crate. `NumInputTypes` mirrors
/// `CU_JIT_NUM_INPUT_TYPES`, which is presumably a count sentinel rather than a
/// real input kind — confirm before passing it to the driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum JitInputType {
    Cubin = driver::CUjitInputType_enum::CU_JIT_INPUT_CUBIN as _,
    Ptx = driver::CUjitInputType_enum::CU_JIT_INPUT_PTX as _,
    Fatbinary = driver::CUjitInputType_enum::CU_JIT_INPUT_FATBINARY as _,
    Object = driver::CUjitInputType_enum::CU_JIT_INPUT_OBJECT as _,
    Library = driver::CUjitInputType_enum::CU_JIT_INPUT_LIBRARY as _,
    #[deprecated]
    Nvvm = driver::CUjitInputType_enum::CU_JIT_INPUT_NVVM as _,
    NumInputTypes = driver::CUjitInputType_enum::CU_JIT_NUM_INPUT_TYPES as _,
}

// NOTE(review): presumably generates conversions between `u32`,
// `driver::CUjitInputType` and `JitInputType`; confirm against `singe_core`.
impl_enum_conversion!(u32, driver::CUjitInputType, JitInputType);

// Maps each variant to its C enumerator name; presumably a `Display` impl.
impl_enum_display!(JitInputType, {
    Self::Cubin => "CU_JIT_INPUT_CUBIN",
    Self::Ptx => "CU_JIT_INPUT_PTX",
    Self::Fatbinary => "CU_JIT_INPUT_FATBINARY",
    Self::Object => "CU_JIT_INPUT_OBJECT",
    Self::Library => "CU_JIT_INPUT_LIBRARY",
    Self::Nvvm => "CU_JIT_INPUT_NVVM",
    Self::NumInputTypes => "CU_JIT_NUM_INPUT_TYPES",
});