// rustacuda/function.rs
//! Functions and types for working with CUDA kernels.

use crate::context::{CacheConfig, SharedMemoryConfig};
use crate::error::{CudaResult, ToResult};
use crate::module::Module;
use cuda_driver_sys::CUfunction;
use std::marker::PhantomData;
use std::mem::transmute;
9
/// Dimensions of a grid, or the number of thread blocks in a kernel launch.
///
/// Each component of a `GridSize` must be at least 1. The maximum size depends on your device's
/// compute capability, but maximums of `x = (2^31)-1, y = 65535, z = 65535` are common. Launching
/// a kernel with a grid size greater than these limits will cause an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GridSize {
    /// Width of grid in blocks
    pub x: u32,
    /// Height of grid in blocks
    pub y: u32,
    /// Depth of grid in blocks
    pub z: u32,
}
impl GridSize {
    /// Create a one-dimensional grid of `x` blocks
    #[inline]
    pub fn x(x: u32) -> GridSize {
        GridSize::xyz(x, 1, 1)
    }

    /// Create a two-dimensional grid of `x * y` blocks
    #[inline]
    pub fn xy(x: u32, y: u32) -> GridSize {
        GridSize::xyz(x, y, 1)
    }

    /// Create a three-dimensional grid of `x * y * z` blocks
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> GridSize {
        GridSize { x, y, z }
    }
}
impl From<u32> for GridSize {
    fn from(width: u32) -> GridSize {
        GridSize::x(width)
    }
}
impl From<(u32, u32)> for GridSize {
    fn from(dims: (u32, u32)) -> GridSize {
        GridSize::xy(dims.0, dims.1)
    }
}
impl From<(u32, u32, u32)> for GridSize {
    fn from(dims: (u32, u32, u32)) -> GridSize {
        GridSize::xyz(dims.0, dims.1, dims.2)
    }
}
impl<'a> From<&'a GridSize> for GridSize {
    fn from(other: &GridSize) -> GridSize {
        other.clone()
    }
}
63
/// Dimensions of a thread block, or the number of threads in a block.
///
/// Each component of a `BlockSize` must be at least 1. The maximum size depends on your device's
/// compute capability, but maximums of `x = 1024, y = 1024, z = 64` are common. In addition, the
/// limit on total number of threads in a block (`x * y * z`) is also defined by the compute
/// capability, typically 1024. Launching a kernel with a block size greater than these limits will
/// cause an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockSize {
    /// X dimension of each thread block
    pub x: u32,
    /// Y dimension of each thread block
    pub y: u32,
    /// Z dimension of each thread block
    pub z: u32,
}
impl BlockSize {
    /// Create a one-dimensional block of `x` threads
    #[inline]
    pub fn x(x: u32) -> BlockSize {
        BlockSize::xyz(x, 1, 1)
    }

    /// Create a two-dimensional block of `x * y` threads
    #[inline]
    pub fn xy(x: u32, y: u32) -> BlockSize {
        BlockSize::xyz(x, y, 1)
    }

    /// Create a three-dimensional block of `x * y * z` threads
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> BlockSize {
        BlockSize { x, y, z }
    }
}
impl From<u32> for BlockSize {
    fn from(width: u32) -> BlockSize {
        BlockSize::x(width)
    }
}
impl From<(u32, u32)> for BlockSize {
    fn from(dims: (u32, u32)) -> BlockSize {
        BlockSize::xy(dims.0, dims.1)
    }
}
impl From<(u32, u32, u32)> for BlockSize {
    fn from(dims: (u32, u32, u32)) -> BlockSize {
        BlockSize::xyz(dims.0, dims.1, dims.2)
    }
}
impl<'a> From<&'a BlockSize> for BlockSize {
    fn from(other: &BlockSize) -> BlockSize {
        other.clone()
    }
}
119
/// All supported function attributes for [Function::get_attribute](struct.Function.html#method.get_attribute)
///
/// NOTE: the explicit discriminants mirror the values of the CUDA driver's
/// `CUfunction_attribute` enum; `get_attribute` transmutes this type into the
/// driver enum, so the discriminants must not be changed.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FunctionAttribute {
    /// The maximum number of threads per block, beyond which a launch would fail. This depends on
    /// both the function and the device.
    MaxThreadsPerBlock = 0,

    /// The size in bytes of the statically-allocated shared memory required by this function.
    SharedMemorySizeBytes = 1,

    /// The size in bytes of the constant memory required by this function
    ConstSizeBytes = 2,

    /// The size in bytes of local memory used by each thread of this function
    LocalSizeBytes = 3,

    /// The number of registers used by each thread of this function
    NumRegisters = 4,

    /// The PTX virtual architecture version for which the function was compiled. This value is the
    /// major PTX version * 10 + the minor PTX version, so version 1.3 would return the value 13.
    PtxVersion = 5,

    /// The binary architecture version for which the function was compiled. Encoded the same way as
    /// PtxVersion.
    BinaryVersion = 6,

    /// The attribute to indicate whether the function has been compiled with user specified
    /// option "-Xptxas --dlcm=ca" set.
    CacheModeCa = 7,

    // Hidden variant reserving room for future attributes; prevents downstream
    // crates from exhaustively matching on this enum.
    #[doc(hidden)]
    __Nonexhaustive = 8,
}
155
/// Handle to a global kernel function.
#[derive(Debug)]
pub struct Function<'a> {
    // Raw CUDA driver handle for the kernel.
    inner: CUfunction,
    // Ties this handle's lifetime to the module it was loaded from, so the
    // module (and thus the kernel code) cannot be dropped while in use.
    module: PhantomData<&'a Module>,
}
impl<'a> Function<'a> {
    // Wrap a raw `CUfunction` obtained from `_module`; the borrow only serves
    // to bind the lifetime parameter.
    pub(crate) fn new(inner: CUfunction, _module: &Module) -> Function {
        Function {
            inner,
            module: PhantomData,
        }
    }

    /// Returns information about a function.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::function::FunctionAttribute;
    /// let function = module.get_function(&name)?;
    /// let shared_memory = function.get_attribute(FunctionAttribute::SharedMemorySizeBytes)?;
    /// println!("This function uses {} bytes of shared memory", shared_memory);
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_attribute(&self, attr: FunctionAttribute) -> CudaResult<i32> {
        unsafe {
            let mut val = 0i32;
            cuda_driver_sys::cuFuncGetAttribute(
                &mut val as *mut i32,
                // This should be safe, as the repr and values of FunctionAttribute should match.
                ::std::mem::transmute(attr),
                self.inner,
            )
            .to_result()?;
            Ok(val)
        }
    }

    /// Sets the preferred cache configuration for this function.
    ///
    /// On devices where L1 cache and shared memory use the same hardware resources, this sets the
    /// preferred cache configuration for this function. This is only a preference. The
    /// driver will use the requested configuration if possible, but is free to choose a different
    /// configuration if required to execute the function. This setting will override the
    /// context-wide setting.
    ///
    /// This setting does nothing on devices where the size of the L1 cache and shared memory are
    /// fixed.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::context::CacheConfig;
    /// let mut function = module.get_function(&name)?;
    /// function.set_cache_config(CacheConfig::PreferL1)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_cache_config(&mut self, config: CacheConfig) -> CudaResult<()> {
        // SAFETY(review): assumes CacheConfig's repr and discriminants match the
        // driver's CUfunc_cache enum — confirm against cuda_driver_sys.
        unsafe { cuda_driver_sys::cuFuncSetCacheConfig(self.inner, transmute(config)).to_result() }
    }

    /// Sets the preferred shared memory configuration for this function.
    ///
    /// On devices with configurable shared memory banks, this function will set this function's
    /// shared memory bank size which is used for subsequent launches of this function. If not set,
    /// the context-wide setting will be used instead.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::context::SharedMemoryConfig;
    /// let mut function = module.get_function(&name)?;
    /// function.set_shared_memory_config(SharedMemoryConfig::EightByteBankSize)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_shared_memory_config(&mut self, cfg: SharedMemoryConfig) -> CudaResult<()> {
        // SAFETY(review): assumes SharedMemoryConfig's repr and discriminants
        // match the driver's CUsharedconfig enum — confirm against cuda_driver_sys.
        unsafe { cuda_driver_sys::cuFuncSetSharedMemConfig(self.inner, transmute(cfg)).to_result() }
    }

    // Expose the raw driver handle to sibling modules (e.g. for launching).
    pub(crate) fn to_inner(&self) -> CUfunction {
        self.inner
    }
}
270
/// Launch a kernel function asynchronously.
///
/// # Syntax:
///
/// The format of this macro is designed to resemble the triple-chevron syntax used to launch
/// kernels in CUDA C. There are two forms available:
///
/// ```ignore
/// let result = launch!(module.function_name<<<grid, block, shared_memory_size, stream>>>(parameter1, parameter2...));
/// ```
///
/// This will load a kernel called `function_name` from the module `module` and launch it with
/// the given grid/block size on the given stream. Unlike in CUDA C, the shared memory size and
/// stream parameters are not optional. The shared memory size is a number of bytes per thread for
/// dynamic shared memory (Note that this uses `extern __shared__ int x[]` in CUDA C, not the
/// fixed-length arrays created by `__shared__ int x[64]`. This will usually be zero.).
/// `stream` must be the name of a [`Stream`](stream/struct.Stream.html) value.
/// `grid` can be any value which implements [`Into<GridSize>`](function/struct.GridSize.html) (such as
/// `u32` values, tuples of up to three `u32` values, and GridSize structures) and likewise `block`
/// can be any value that implements [`Into<BlockSize>`](function/struct.BlockSize.html).
///
/// NOTE: due to some limitations of Rust's macro system, `module` and `stream` must be local
/// variable names. Paths or function calls will not work.
///
/// The second form is similar:
///
/// ```ignore
/// let result = launch!(function<<<grid, block, shared_memory_size, stream>>>(parameter1, parameter2...));
/// ```
///
/// In this variant, the `function` parameter must be a variable. Use this form to avoid looking up
/// the kernel function for each call.
///
/// # Safety
///
/// Launching kernels must be done in an `unsafe` block. Calling a kernel is similar to calling a
/// foreign-language function, as the kernel itself could be written in C or unsafe Rust. The kernel
/// must accept the same number and type of parameters that are passed to the `launch!` macro. The
/// kernel must not write invalid data (for example, invalid enums) into areas of memory that can
/// be copied back to the host. The programmer must ensure that the host does not access device or
/// unified memory that the kernel could write to until after calling `stream.synchronize()`.
///
/// # Examples
///
/// ```
/// # #[macro_use]
/// # use rustacuda::*;
/// # use std::error::Error;
/// use rustacuda::memory::*;
/// use rustacuda::module::Module;
/// use rustacuda::stream::*;
/// use std::ffi::CString;
///
/// # fn main() -> Result<(), Box<dyn Error>> {
///
/// // Set up the context, load the module, and create a stream to run kernels in.
/// let _ctx = rustacuda::quick_init()?;
/// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
/// let module = Module::load_from_string(&ptx)?;
/// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
///
/// // Create buffers for data
/// let mut in_x = DeviceBuffer::from_slice(&[1.0f32; 10])?;
/// let mut in_y = DeviceBuffer::from_slice(&[2.0f32; 10])?;
/// let mut out_1 = DeviceBuffer::from_slice(&[0.0f32; 10])?;
/// let mut out_2 = DeviceBuffer::from_slice(&[0.0f32; 10])?;
///
/// // This kernel adds each element in `in_x` and `in_y` and writes the result into `out`.
/// unsafe {
///     // Launch the kernel with one block of one thread, no dynamic shared memory on `stream`.
///     let result = launch!(module.sum<<<1, 1, 0, stream>>>(
///         in_x.as_device_ptr(),
///         in_y.as_device_ptr(),
///         out_1.as_device_ptr(),
///         out_1.len()
///     ));
///     // `launch!` returns an error in case anything went wrong with the launch itself, but
///     // kernel launches are asynchronous so errors caused by the kernel (eg. invalid memory
///     // access) will show up later at some other CUDA API call (probably at `synchronize()`
///     // below).
///     result?;
///
///     // Launch the kernel again using the `function` form:
///     let function_name = CString::new("sum")?;
///     let sum = module.get_function(&function_name)?;
///     // Launch with 1x1x1 (1) blocks of 10x1x1 (10) threads, to show that you can use tuples to
///     // configure grid and block size.
///     let result = launch!(sum<<<(1, 1, 1), (10, 1, 1), 0, stream>>>(
///         in_x.as_device_ptr(),
///         in_y.as_device_ptr(),
///         out_2.as_device_ptr(),
///         out_2.len()
///     ));
///     result?;
/// }
///
/// // Kernel launches are asynchronous, so we wait for the kernels to finish executing.
/// stream.synchronize()?;
///
/// // Copy the results back to host memory
/// let mut out_host = [0.0f32; 20];
/// out_1.copy_to(&mut out_host[0..10])?;
/// out_2.copy_to(&mut out_host[10..20])?;
///
/// for x in out_host.iter() {
///     assert_eq!(3.0, *x);
/// }
/// # Ok(())
/// # }
/// ```
///
#[macro_export]
macro_rules! launch {
    // `module.function` form: look the kernel up by its (stringified) name,
    // then delegate to the function-variable arm below.
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            let name = std::ffi::CString::new(stringify!($function)).unwrap();
            let function = $module.get_function(&name);
            match function {
                Ok(f) => launch!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
                Err(e) => Err(e),
            }
        }
    };
    // Function-variable form: performs the actual launch.
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            // Compile-time check that every argument implements DeviceCopy;
            // the `if false` block is never executed at runtime.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };

            // Pass a raw pointer to each argument through to the stream launch.
            $stream.launch(&$function, $grid, $block, $shared,
                &[
                    $(
                        &$arg as *const _ as *mut ::std::ffi::c_void,
                    )*
                ]
            )
        }
    };
}
413
#[cfg(test)]
mod test {
    use super::*;
    use crate::memory::CopyDestination;
    use crate::memory::DeviceBuffer;
    use crate::quick_init;
    use crate::stream::{Stream, StreamFlags};
    use std::error::Error;
    use std::ffi::CString;

    /// Launches the `sum` kernel from the bundled PTX and verifies that it
    /// adds two device buffers element-wise (2.0 + 1.0 == 3.0 for all 128
    /// elements).
    #[test]
    fn test_launch() -> Result<(), Box<dyn Error>> {
        // Fix: propagate initialization failure with `?`. Previously the
        // `CudaResult` was bound without being checked, so a failed init was
        // silently ignored and surfaced later as an unrelated CUDA error.
        let _context = quick_init()?;
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;

        unsafe {
            let mut in_x = DeviceBuffer::from_slice(&[2.0f32; 128])?;
            let mut in_y = DeviceBuffer::from_slice(&[1.0f32; 128])?;
            let mut out: DeviceBuffer<f32> = DeviceBuffer::uninitialized(128)?;

            let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
            launch!(module.sum<<<1, 128, 0, stream>>>(in_x.as_device_ptr(), in_y.as_device_ptr(), out.as_device_ptr(), out.len()))?;
            // Launches are asynchronous; wait for the kernel to finish before
            // reading the results back.
            stream.synchronize()?;

            let mut out_host = [0f32; 128];
            out.copy_to(&mut out_host[..])?;
            for x in out_host.iter() {
                assert_eq!(3, *x as u32);
            }
        }
        Ok(())
    }
}
447}