1#![allow(deprecated)]
2
3use std::ptr;
4
5use num_enum::{IntoPrimitive, TryFromPrimitive};
6use singe_cuda_sys::{driver, runtime};
7
8use singe_core::{impl_enum_conversion, impl_enum_display};
9
10#[derive(Debug, Default)]
11pub struct JitOptions<'a> {
12 pub max_registers: Option<u32>,
13 pub threads_per_block: Option<u32>,
14 pub wall_time: Option<&'a mut f32>,
15 pub info_log_buffer: Option<&'a mut [u8]>,
16 pub error_log_buffer: Option<&'a mut [u8]>,
17 pub optimization_level: Option<u32>,
18 pub target_from_cuda_context: Option<()>,
19 pub target: Option<JitTarget>,
20 pub fallback_strategy: Option<JitFallback>,
21 pub generate_debug_info: Option<bool>,
22 pub log_verbose: Option<bool>,
23 pub generate_line_info: Option<bool>,
24 pub cache_mode: Option<JitCacheMode>,
25}
26
27pub struct JitOptionsArtifact {
28 pub names: Vec<driver::CUjit_option>,
29 pub values: Vec<*mut ()>,
30
31 storage_target: Option<u32>,
33 storage_fallback: Option<u32>,
34 storage_debug_info: Option<i32>,
35 storage_log_verbose: Option<i32>,
36 storage_line_info: Option<i32>,
37 storage_cache_mode: Option<u32>,
38
39 storage_info_log_ptr: Option<*mut u8>,
41 storage_info_log_size: Option<u32>,
42 storage_error_log_ptr: Option<*mut u8>,
43 storage_error_log_size: Option<u32>,
44 storage_max_registers: Option<u32>,
45 storage_threads_per_block: Option<u32>,
46 storage_optimization_level: Option<u32>,
47}
48
49impl<'a> JitOptions<'a> {
50 pub fn with_max_registers(mut self, value: u32) -> Self {
51 self.max_registers = Some(value);
52 self
53 }
54
55 pub fn with_threads_per_block(mut self, value: u32) -> Self {
56 self.threads_per_block = Some(value);
57 self
58 }
59
60 pub fn with_wall_time(mut self, value: &'a mut f32) -> Self {
61 self.wall_time = Some(value);
62 self
63 }
64
65 pub fn with_info_log(mut self, buffer: &'a mut [u8]) -> Self {
66 self.info_log_buffer = Some(buffer);
67 self
68 }
69
70 pub fn with_error_log(mut self, buffer: &'a mut [u8]) -> Self {
71 self.error_log_buffer = Some(buffer);
72 self
73 }
74
75 pub fn with_optimization_level(mut self, level: u32) -> Self {
76 self.optimization_level = Some(level.min(4));
78 self
79 }
80
81 pub const fn with_target_from_cuda_context(mut self) -> Self {
82 self.target_from_cuda_context = Some(());
83 self
84 }
85
86 pub fn with_target(mut self, target: JitTarget) -> Self {
87 self.target = Some(target);
88 self
89 }
90
91 pub fn with_fallback_strategy(mut self, strategy: JitFallback) -> Self {
92 self.fallback_strategy = Some(strategy);
93 self
94 }
95
96 pub fn with_generate_debug_info(mut self, enable: bool) -> Self {
97 self.generate_debug_info = Some(enable);
98 self
99 }
100
101 pub fn with_log_verbose(mut self, enable: bool) -> Self {
102 self.log_verbose = Some(enable);
103 self
104 }
105
106 pub fn with_generate_line_info(mut self, enable: bool) -> Self {
107 self.generate_line_info = Some(enable);
108 self
109 }
110
111 pub fn with_cache_mode(mut self, mode: JitCacheMode) -> Self {
112 self.cache_mode = Some(mode);
113 self
114 }
115
116 pub fn build(&mut self) -> JitOptionsArtifact {
123 let mut artifact = JitOptionsArtifact {
124 names: Vec::new(),
125 values: Vec::new(),
126 storage_target: None,
127 storage_fallback: None,
128 storage_debug_info: None,
129 storage_log_verbose: None,
130 storage_line_info: None,
131 storage_cache_mode: None,
132 storage_info_log_ptr: None,
133 storage_info_log_size: None,
134 storage_error_log_ptr: None,
135 storage_error_log_size: None,
136 storage_max_registers: self.max_registers,
137 storage_threads_per_block: self.threads_per_block,
138 storage_optimization_level: self.optimization_level.map(|value| value.clamp(0, 4)),
139 };
140
141 artifact.storage_target = self.target.map(Into::into);
143 artifact.storage_fallback = self.fallback_strategy.map(Into::into);
144 artifact.storage_cache_mode = self.cache_mode.map(Into::into);
145 artifact.storage_debug_info = self.generate_debug_info.map(i32::from);
146 artifact.storage_log_verbose = self.log_verbose.map(i32::from);
147 artifact.storage_line_info = self.generate_line_info.map(i32::from);
148 artifact.storage_info_log_ptr = self
149 .info_log_buffer
150 .as_mut()
151 .map(|slice| slice.as_mut_ptr().cast::<u8>());
152 artifact.storage_info_log_size = self
153 .info_log_buffer
154 .as_ref()
155 .map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
156 artifact.storage_error_log_ptr = self
157 .error_log_buffer
158 .as_mut()
159 .map(|slice| slice.as_mut_ptr().cast::<u8>());
160 artifact.storage_error_log_size = self
161 .error_log_buffer
162 .as_ref()
163 .map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
164
165 if let Some(ref mut val) = artifact.storage_max_registers {
167 artifact.names.push(JitOption::MaxRegisters.into());
168 artifact.values.push(ptr::from_mut::<u32>(val).cast());
169 }
170 if let Some(ref mut val) = artifact.storage_threads_per_block {
171 artifact.names.push(JitOption::ThreadsPerBlock.into());
172 artifact.values.push(ptr::from_mut::<u32>(val).cast());
173 }
174 if let Some(ref mut val_ref) = self.wall_time {
176 artifact.names.push(JitOption::WallTime.into());
177 artifact.values.push(ptr::from_mut::<f32>(*val_ref).cast());
178 }
179
180 if let Some(info_log_ptr) = artifact.storage_info_log_ptr {
182 artifact.names.push(JitOption::InfoLogBuffer.into());
183 artifact.values.push(info_log_ptr.cast());
184
185 if let Some(ref mut size_val) = artifact.storage_info_log_size {
186 artifact
187 .names
188 .push(JitOption::InfoLogBufferSizeBytes.into());
189 artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
190 }
191 }
192 if let Some(error_log_ptr) = artifact.storage_error_log_ptr {
193 artifact.names.push(JitOption::ErrorLogBuffer.into());
194 artifact.values.push(error_log_ptr.cast());
195
196 if let Some(ref mut size_val) = artifact.storage_error_log_size {
197 artifact
198 .names
199 .push(JitOption::ErrorLogBufferSizeBytes.into());
200 artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
201 }
202 }
203
204 if let Some(ref mut value) = artifact.storage_optimization_level {
205 artifact.names.push(JitOption::OptimizationLevel.into());
206 artifact.values.push(ptr::from_mut::<u32>(value).cast());
207 }
208 if self.target_from_cuda_context.is_some() {
209 artifact.names.push(JitOption::TargetFromCudaContext.into());
210 artifact.values.push(ptr::null_mut()); }
212 if let Some(ref mut val) = artifact.storage_target {
214 artifact.names.push(JitOption::Target.into());
215 artifact.values.push(ptr::from_mut::<u32>(val).cast());
216 }
217 if let Some(ref mut val) = artifact.storage_fallback {
218 artifact.names.push(JitOption::FallbackStrategy.into());
219 artifact.values.push(ptr::from_mut::<u32>(val).cast());
220 }
221 if let Some(ref mut val) = artifact.storage_debug_info {
222 artifact.names.push(JitOption::GenerateDebugInfo.into());
223 artifact.values.push(ptr::from_mut::<i32>(val).cast());
224 }
225 if let Some(ref mut val) = artifact.storage_log_verbose {
226 artifact.names.push(JitOption::LogVerbose.into());
227 artifact.values.push(ptr::from_mut::<i32>(val).cast());
228 }
229 if let Some(ref mut val) = artifact.storage_line_info {
230 artifact.names.push(JitOption::GenerateLineInfo.into());
231 artifact.values.push(ptr::from_mut::<i32>(val).cast());
232 }
233 if let Some(ref mut val) = artifact.storage_cache_mode {
234 artifact.names.push(JitOption::CacheMode.into());
235 artifact.values.push(ptr::from_mut::<u32>(val).cast());
236 }
237
238 artifact
239 }
240}
241
242#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
243#[repr(u32)]
244#[non_exhaustive]
245pub enum JitOption {
246 MaxRegisters = runtime::cudaJitOption::CU_JIT_MAX_REGISTERS as _,
247 ThreadsPerBlock = runtime::cudaJitOption::CU_JIT_THREADS_PER_BLOCK as _,
248 WallTime = runtime::cudaJitOption::CU_JIT_WALL_TIME as _,
249 InfoLogBuffer = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER as _,
250 InfoLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES as _,
251 ErrorLogBuffer = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER as _,
252 ErrorLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES as _,
253 OptimizationLevel = runtime::cudaJitOption::CU_JIT_OPTIMIZATION_LEVEL as _,
254 TargetFromCudaContext = runtime::cudaJitOption::CU_JIT_TARGET_FROM_CUCONTEXT as _,
255 Target = runtime::cudaJitOption::CU_JIT_TARGET as _,
256 FallbackStrategy = runtime::cudaJitOption::CU_JIT_FALLBACK_STRATEGY as _,
257 GenerateDebugInfo = runtime::cudaJitOption::CU_JIT_GENERATE_DEBUG_INFO as _,
258 LogVerbose = runtime::cudaJitOption::CU_JIT_LOG_VERBOSE as _,
259 GenerateLineInfo = runtime::cudaJitOption::CU_JIT_GENERATE_LINE_INFO as _,
260 CacheMode = runtime::cudaJitOption::CU_JIT_CACHE_MODE as _,
261 #[deprecated]
262 NewSm3xOpt = runtime::cudaJitOption::CU_JIT_NEW_SM3X_OPT as _,
263 FastCompile = runtime::cudaJitOption::CU_JIT_FAST_COMPILE as _,
264 GlobalSymbolNames = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_NAMES as _,
265 GlobalSymbolAddresses = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_ADDRESSES as _,
266 GlobalSymbolCount = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_COUNT as _,
267 #[deprecated]
268 Lto = runtime::cudaJitOption::CU_JIT_LTO as _,
269 #[deprecated]
270 Ftz = runtime::cudaJitOption::CU_JIT_FTZ as _,
271 #[deprecated]
272 PrecDiv = runtime::cudaJitOption::CU_JIT_PREC_DIV as _,
273 #[deprecated]
274 PrecSqrt = runtime::cudaJitOption::CU_JIT_PREC_SQRT as _,
275 #[deprecated]
276 Fma = runtime::cudaJitOption::CU_JIT_FMA as _,
277 #[deprecated]
278 ReferencedKernelNames = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_NAMES as _,
279 #[deprecated]
280 ReferencedKernelCount = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_COUNT as _,
281 #[deprecated]
282 ReferencedVariableNames = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_NAMES as _,
283 #[deprecated]
284 ReferencedVariableCount = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_COUNT as _,
285 #[deprecated]
286 OptimizeUnusedDeviceVariables =
287 runtime::cudaJitOption::CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES as _,
288 NumOptions = runtime::cudaJitOption::CU_JIT_NUM_OPTIONS as _,
289}
290
291impl_enum_conversion!(u32, runtime::cudaJitOption, JitOption);
292
293impl_enum_display!(JitOption, {
294 Self::MaxRegisters => "CU_JIT_MAX_REGISTERS",
295 Self::ThreadsPerBlock => "CU_JIT_THREADS_PER_BLOCK",
296 Self::WallTime => "CU_JIT_WALL_TIME",
297 Self::InfoLogBuffer => "CU_JIT_INFO_LOG_BUFFER",
298 Self::InfoLogBufferSizeBytes => "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES",
299 Self::ErrorLogBuffer => "CU_JIT_ERROR_LOG_BUFFER",
300 Self::ErrorLogBufferSizeBytes => "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES",
301 Self::OptimizationLevel => "CU_JIT_OPTIMIZATION_LEVEL",
302 Self::TargetFromCudaContext => "CU_JIT_TARGET_FROM_CUCONTEXT",
303 Self::Target => "CU_JIT_TARGET",
304 Self::FallbackStrategy => "CU_JIT_FALLBACK_STRATEGY",
305 Self::GenerateDebugInfo => "CU_JIT_GENERATE_DEBUG_INFO",
306 Self::LogVerbose => "CU_JIT_LOG_VERBOSE",
307 Self::GenerateLineInfo => "CU_JIT_GENERATE_LINE_INFO",
308 Self::CacheMode => "CU_JIT_CACHE_MODE",
309 Self::NewSm3xOpt => "CU_JIT_NEW_SM3X_OPT",
310 Self::FastCompile => "CU_JIT_FAST_COMPILE",
311 Self::GlobalSymbolNames => "CU_JIT_GLOBAL_SYMBOL_NAMES",
312 Self::GlobalSymbolAddresses => "CU_JIT_GLOBAL_SYMBOL_ADDRESSES",
313 Self::GlobalSymbolCount => "CU_JIT_GLOBAL_SYMBOL_COUNT",
314 Self::Lto => "CU_JIT_LTO",
315 Self::Ftz => "CU_JIT_FTZ",
316 Self::PrecDiv => "CU_JIT_PREC_DIV",
317 Self::PrecSqrt => "CU_JIT_PREC_SQRT",
318 Self::Fma => "CU_JIT_FMA",
319 Self::ReferencedKernelNames => "CU_JIT_REFERENCED_KERNEL_NAMES",
320 Self::ReferencedKernelCount => "CU_JIT_REFERENCED_KERNEL_COUNT",
321 Self::ReferencedVariableNames => "CU_JIT_REFERENCED_VARIABLE_NAMES",
322 Self::ReferencedVariableCount => "CU_JIT_REFERENCED_VARIABLE_COUNT",
323 Self::OptimizeUnusedDeviceVariables => "CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES",
324 Self::NumOptions => "CU_JIT_NUM_OPTIONS",
325});
326
327#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
328#[repr(u32)]
329#[non_exhaustive]
330pub enum JitTarget {
331 Compute30 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_30 as _,
332 Compute32 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_32 as _,
333 Compute35 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_35 as _,
334 Compute37 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_37 as _,
335 Compute50 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_50 as _,
336 Compute52 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_52 as _,
337 Compute53 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_53 as _,
338 Compute60 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_60 as _,
339 Compute61 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_61 as _,
340 Compute62 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_62 as _,
341 Compute70 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_70 as _,
342 Compute72 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_72 as _,
343 Compute75 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_75 as _,
344 Compute80 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_80 as _,
345 Compute86 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_86 as _,
346 Compute87 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_87 as _,
347 Compute89 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_89 as _,
348 Compute90 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_90 as _,
349}
350
351impl_enum_conversion!(u32, driver::CUjit_target, JitTarget);
352
353impl_enum_display!(JitTarget, {
354 Self::Compute30 => "CU_TARGET_COMPUTE_30",
355 Self::Compute32 => "CU_TARGET_COMPUTE_32",
356 Self::Compute35 => "CU_TARGET_COMPUTE_35",
357 Self::Compute37 => "CU_TARGET_COMPUTE_37",
358 Self::Compute50 => "CU_TARGET_COMPUTE_50",
359 Self::Compute52 => "CU_TARGET_COMPUTE_52",
360 Self::Compute53 => "CU_TARGET_COMPUTE_53",
361 Self::Compute60 => "CU_TARGET_COMPUTE_60",
362 Self::Compute61 => "CU_TARGET_COMPUTE_61",
363 Self::Compute62 => "CU_TARGET_COMPUTE_62",
364 Self::Compute70 => "CU_TARGET_COMPUTE_70",
365 Self::Compute72 => "CU_TARGET_COMPUTE_72",
366 Self::Compute75 => "CU_TARGET_COMPUTE_75",
367 Self::Compute80 => "CU_TARGET_COMPUTE_80",
368 Self::Compute86 => "CU_TARGET_COMPUTE_86",
369 Self::Compute87 => "CU_TARGET_COMPUTE_87",
370 Self::Compute89 => "CU_TARGET_COMPUTE_89",
371 Self::Compute90 => "CU_TARGET_COMPUTE_90",
372});
373
374#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
375#[repr(u32)]
376#[non_exhaustive]
377pub enum JitFallback {
378 PreferPtx = driver::CUjit_fallback_enum::CU_PREFER_PTX as _,
379 PreferBinary = driver::CUjit_fallback_enum::CU_PREFER_BINARY as _,
380}
381
382impl_enum_conversion!(u32, driver::CUjit_fallback, JitFallback);
383
384impl_enum_display!(JitFallback, {
385 Self::PreferPtx => "CU_PREFER_PTX",
386 Self::PreferBinary => "CU_PREFER_BINARY",
387});
388
389#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
390#[repr(u32)]
391#[non_exhaustive]
392pub enum JitCacheMode {
393 OptionNone = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_NONE as _,
394 OptionCg = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CG as _,
395 OptionCa = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CA as _,
396}
397
398impl_enum_conversion!(u32, driver::CUjit_cacheMode, JitCacheMode);
399
400impl_enum_display!(JitCacheMode, {
401 Self::OptionNone => "CU_JIT_CACHE_OPTION_NONE",
402 Self::OptionCg => "CU_JIT_CACHE_OPTION_CG",
403 Self::OptionCa => "CU_JIT_CACHE_OPTION_CA",
404});
405
406#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
407#[repr(u32)]
408#[non_exhaustive]
409pub enum JitInputType {
410 Cubin = driver::CUjitInputType_enum::CU_JIT_INPUT_CUBIN as _,
411 Ptx = driver::CUjitInputType_enum::CU_JIT_INPUT_PTX as _,
412 Fatbinary = driver::CUjitInputType_enum::CU_JIT_INPUT_FATBINARY as _,
413 Object = driver::CUjitInputType_enum::CU_JIT_INPUT_OBJECT as _,
414 Library = driver::CUjitInputType_enum::CU_JIT_INPUT_LIBRARY as _,
415 #[deprecated]
416 Nvvm = driver::CUjitInputType_enum::CU_JIT_INPUT_NVVM as _,
417 NumInputTypes = driver::CUjitInputType_enum::CU_JIT_NUM_INPUT_TYPES as _,
418}
419
420impl_enum_conversion!(u32, driver::CUjitInputType, JitInputType);
421
422impl_enum_display!(JitInputType, {
423 Self::Cubin => "CU_JIT_INPUT_CUBIN",
424 Self::Ptx => "CU_JIT_INPUT_PTX",
425 Self::Fatbinary => "CU_JIT_INPUT_FATBINARY",
426 Self::Object => "CU_JIT_INPUT_OBJECT",
427 Self::Library => "CU_JIT_INPUT_LIBRARY",
428 Self::Nvvm => "CU_JIT_INPUT_NVVM",
429 Self::NumInputTypes => "CU_JIT_NUM_INPUT_TYPES",
430});