rustacuda/
function.rs

1//! Functions and types for working with CUDA kernels.
2
3use crate::context::{CacheConfig, SharedMemoryConfig};
4use crate::error::{CudaResult, ToResult};
5use crate::module::Module;
6use cuda_driver_sys::CUfunction;
7use std::marker::PhantomData;
8use std::mem::transmute;
9
/// Dimensions of a grid, or the number of thread blocks in a kernel launch.
///
/// Each component of a `GridSize` must be at least 1. The maximum size depends on your device's
/// compute capability, but maximums of `x = (2^31)-1, y = 65535, z = 65535` are common. Launching
/// a kernel with a grid size greater than these limits will cause an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GridSize {
    /// Width of grid in blocks
    pub x: u32,
    /// Height of grid in blocks
    pub y: u32,
    /// Depth of grid in blocks
    pub z: u32,
}
impl GridSize {
    /// Create a one-dimensional grid of `x` blocks
    #[inline]
    pub fn x(x: u32) -> GridSize {
        GridSize::xyz(x, 1, 1)
    }

    /// Create a two-dimensional grid of `x * y` blocks
    #[inline]
    pub fn xy(x: u32, y: u32) -> GridSize {
        GridSize::xyz(x, y, 1)
    }

    /// Create a three-dimensional grid of `x * y * z` blocks
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> GridSize {
        GridSize { x, y, z }
    }
}
impl From<u32> for GridSize {
    /// A single `u32` is interpreted as a one-dimensional grid.
    fn from(x: u32) -> GridSize {
        GridSize { x, y: 1, z: 1 }
    }
}
impl From<(u32, u32)> for GridSize {
    /// A pair is interpreted as a two-dimensional grid.
    fn from(dims: (u32, u32)) -> GridSize {
        GridSize {
            x: dims.0,
            y: dims.1,
            z: 1,
        }
    }
}
impl From<(u32, u32, u32)> for GridSize {
    /// A triple maps directly onto the three grid dimensions.
    fn from(dims: (u32, u32, u32)) -> GridSize {
        GridSize {
            x: dims.0,
            y: dims.1,
            z: dims.2,
        }
    }
}
impl<'a> From<&'a GridSize> for GridSize {
    /// Borrowed grid sizes convert by copying each component.
    fn from(other: &GridSize) -> GridSize {
        GridSize {
            x: other.x,
            y: other.y,
            z: other.z,
        }
    }
}
63
/// Dimensions of a thread block, or the number of threads in a block.
///
/// Each component of a `BlockSize` must be at least 1. The maximum size depends on your device's
/// compute capability, but maximums of `x = 1024, y = 1024, z = 64` are common. In addition, the
/// limit on total number of threads in a block (`x * y * z`) is also defined by the compute
/// capability, typically 1024. Launching a kernel with a block size greater than these limits will
/// cause an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockSize {
    /// X dimension of each thread block
    pub x: u32,
    /// Y dimension of each thread block
    pub y: u32,
    /// Z dimension of each thread block
    pub z: u32,
}
impl BlockSize {
    /// Create a one-dimensional block of `x` threads
    #[inline]
    pub fn x(x: u32) -> BlockSize {
        BlockSize::xyz(x, 1, 1)
    }

    /// Create a two-dimensional block of `x * y` threads
    #[inline]
    pub fn xy(x: u32, y: u32) -> BlockSize {
        BlockSize::xyz(x, y, 1)
    }

    /// Create a three-dimensional block of `x * y * z` threads
    #[inline]
    pub fn xyz(x: u32, y: u32, z: u32) -> BlockSize {
        BlockSize { x, y, z }
    }
}
impl From<u32> for BlockSize {
    /// A single `u32` is interpreted as a one-dimensional block.
    fn from(x: u32) -> BlockSize {
        BlockSize { x, y: 1, z: 1 }
    }
}
impl From<(u32, u32)> for BlockSize {
    /// A pair is interpreted as a two-dimensional block.
    fn from(dims: (u32, u32)) -> BlockSize {
        BlockSize {
            x: dims.0,
            y: dims.1,
            z: 1,
        }
    }
}
impl From<(u32, u32, u32)> for BlockSize {
    /// A triple maps directly onto the three block dimensions.
    fn from(dims: (u32, u32, u32)) -> BlockSize {
        BlockSize {
            x: dims.0,
            y: dims.1,
            z: dims.2,
        }
    }
}
impl<'a> From<&'a BlockSize> for BlockSize {
    /// Borrowed block sizes convert by copying each component.
    fn from(other: &BlockSize) -> BlockSize {
        BlockSize {
            x: other.x,
            y: other.y,
            z: other.z,
        }
    }
}
119
/// All supported function attributes for [Function::get_attribute](struct.Function.html#method.get_attribute)
///
/// NOTE(review): values of this enum are `transmute`d into CUDA's `CUfunction_attribute`
/// in `Function::get_attribute`, so the `#[repr(u32)]` representation and the explicit
/// discriminant values below must stay in sync with that enum. Do not reorder or renumber
/// the variants.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FunctionAttribute {
    /// The maximum number of threads per block, beyond which a launch would fail. This depends on
    /// both the function and the device.
    MaxThreadsPerBlock = 0,

    /// The size in bytes of the statically-allocated shared memory required by this function.
    SharedMemorySizeBytes = 1,

    /// The size in bytes of the constant memory required by this function
    ConstSizeBytes = 2,

    /// The size in bytes of local memory used by each thread of this function
    LocalSizeBytes = 3,

    /// The number of registers used by each thread of this function
    NumRegisters = 4,

    /// The PTX virtual architecture version for which the function was compiled. This value is the
    /// major PTX version * 10 + the minor PTX version, so version 1.3 would return the value 13.
    PtxVersion = 5,

    /// The binary architecture version for which the function was compiled. Encoded the same way as
    /// PtxVersion.
    BinaryVersion = 6,

    /// The attribute to indicate whether the function has been compiled with user specified
    /// option "-Xptxas --dlcm=ca" set.
    CacheModeCa = 7,

    // Hidden variant that prevents exhaustive matching, so new attributes can be
    // added without a breaking change (pre-`#[non_exhaustive]` idiom).
    #[doc(hidden)]
    __Nonexhaustive = 8,
}
155
/// Handle to a global kernel function.
#[derive(Debug)]
pub struct Function<'a> {
    // Raw CUDA driver-API handle to the loaded kernel function.
    inner: CUfunction,
    // Zero-sized marker tying this handle's lifetime to the `Module` it was loaded
    // from, so a `Function` cannot outlive the module that owns the kernel code.
    module: PhantomData<&'a Module>,
}
impl<'a> Function<'a> {
    /// Wrap a raw `CUfunction` handle. The `_module` borrow is unused at runtime but
    /// anchors the `'a` lifetime (via `PhantomData`) to the module the handle came from.
    pub(crate) fn new(inner: CUfunction, _module: &Module) -> Function {
        Function {
            inner,
            module: PhantomData,
        }
    }

    /// Returns information about a function.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::function::FunctionAttribute;
    /// let function = module.get_function(&name)?;
    /// let shared_memory = function.get_attribute(FunctionAttribute::SharedMemorySizeBytes)?;
    /// println!("This function uses {} bytes of shared memory", shared_memory);
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_attribute(&self, attr: FunctionAttribute) -> CudaResult<i32> {
        unsafe {
            let mut val = 0i32;
            cuda_driver_sys::cuFuncGetAttribute(
                &mut val as *mut i32,
                // This should be safe, as the repr and values of FunctionAttribute should match.
                ::std::mem::transmute(attr),
                self.inner,
            )
            .to_result()?;
            Ok(val)
        }
    }

    /// Sets the preferred cache configuration for this function.
    ///
    /// On devices where L1 cache and shared memory use the same hardware resources, this sets the
    /// preferred cache configuration for this function. This is only a preference. The
    /// driver will use the requested configuration if possible, but is free to choose a different
    /// configuration if required to execute the function. This setting will override the
    /// context-wide setting.
    ///
    /// This setting does nothing on devices where the size of the L1 cache and shared memory are
    /// fixed.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::context::CacheConfig;
    /// let mut function = module.get_function(&name)?;
    /// function.set_cache_config(CacheConfig::PreferL1)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_cache_config(&mut self, config: CacheConfig) -> CudaResult<()> {
        // SAFETY(review): relies on `CacheConfig` sharing the repr and discriminant
        // values of CUDA's `CUfunc_cache` enum — confirm against the context module.
        unsafe { cuda_driver_sys::cuFuncSetCacheConfig(self.inner, transmute(config)).to_result() }
    }

    /// Sets the preferred shared memory configuration for this function.
    ///
    /// On devices with configurable shared memory banks, this function will set this function's
    /// shared memory bank size which is used for subsequent launches of this function. If not set,
    /// the context-wide setting will be used instead.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// # use rustacuda::module::Module;
    /// # use std::ffi::CString;
    /// # let ptx = CString::new(include_str!("../resources/add.ptx"))?;
    /// # let module = Module::load_from_string(&ptx)?;
    /// # let name = CString::new("sum")?;
    /// use rustacuda::context::SharedMemoryConfig;
    /// let mut function = module.get_function(&name)?;
    /// function.set_shared_memory_config(SharedMemoryConfig::EightByteBankSize)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn set_shared_memory_config(&mut self, cfg: SharedMemoryConfig) -> CudaResult<()> {
        // SAFETY(review): relies on `SharedMemoryConfig` sharing the repr and discriminant
        // values of CUDA's `CUsharedconfig` enum — confirm against the context module.
        unsafe { cuda_driver_sys::cuFuncSetSharedMemConfig(self.inner, transmute(cfg)).to_result() }
    }

    /// Expose the raw handle for use by other crate internals (e.g. kernel launch).
    pub(crate) fn to_inner(&self) -> CUfunction {
        self.inner
    }
}
270
/// Launch a kernel function asynchronously.
///
/// # Syntax:
///
/// The format of this macro is designed to resemble the triple-chevron syntax used to launch
/// kernels in CUDA C. There are two forms available:
///
/// ```ignore
/// let result = launch!(module.function_name<<<grid, block, shared_memory_size, stream>>>(parameter1, parameter2...));
/// ```
///
/// This will load a kernel called `function_name` from the module `module` and launch it with
/// the given grid/block size on the given stream. Unlike in CUDA C, the shared memory size and
/// stream parameters are not optional. The shared memory size is the number of bytes per block of
/// dynamically-allocated shared memory (Note that this uses `extern __shared__ int x[]` in CUDA C,
/// not the fixed-length arrays created by `__shared__ int x[64]`. This will usually be zero.).
/// `stream` must be the name of a [`Stream`](stream/struct.Stream.html) value.
/// `grid` can be any value which implements [`Into<GridSize>`](function/struct.GridSize.html) (such as
/// `u32` values, tuples of up to three `u32` values, and GridSize structures) and likewise `block`
/// can be any value that implements [`Into<BlockSize>`](function/struct.BlockSize.html).
///
/// NOTE: due to some limitations of Rust's macro system, `module` and `stream` must be local
/// variable names. Paths or function calls will not work.
///
/// The second form is similar:
///
/// ```ignore
/// let result = launch!(function<<<grid, block, shared_memory_size, stream>>>(parameter1, parameter2...));
/// ```
///
/// In this variant, the `function` parameter must be a variable. Use this form to avoid looking up
/// the kernel function for each call.
///
/// # Safety
///
/// Launching kernels must be done in an `unsafe` block. Calling a kernel is similar to calling a
/// foreign-language function, as the kernel itself could be written in C or unsafe Rust. The kernel
/// must accept the same number and type of parameters that are passed to the `launch!` macro. The
/// kernel must not write invalid data (for example, invalid enums) into areas of memory that can
/// be copied back to the host. The programmer must ensure that the host does not access device or
/// unified memory that the kernel could write to until after calling `stream.synchronize()`.
///
/// # Examples
///
/// ```
/// # #[macro_use]
/// # use rustacuda::*;
/// # use std::error::Error;
/// use rustacuda::memory::*;
/// use rustacuda::module::Module;
/// use rustacuda::stream::*;
/// use std::ffi::CString;
///
/// # fn main() -> Result<(), Box<dyn Error>> {
///
/// // Set up the context, load the module, and create a stream to run kernels in.
/// let _ctx = rustacuda::quick_init()?;
/// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
/// let module = Module::load_from_string(&ptx)?;
/// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
///
/// // Create buffers for data
/// let mut in_x = DeviceBuffer::from_slice(&[1.0f32; 10])?;
/// let mut in_y = DeviceBuffer::from_slice(&[2.0f32; 10])?;
/// let mut out_1 = DeviceBuffer::from_slice(&[0.0f32; 10])?;
/// let mut out_2 = DeviceBuffer::from_slice(&[0.0f32; 10])?;
///
/// // This kernel adds each element in `in_x` and `in_y` and writes the result into `out`.
/// unsafe {
///     // Launch the kernel with one block of one thread, no dynamic shared memory on `stream`.
///     let result = launch!(module.sum<<<1, 1, 0, stream>>>(
///         in_x.as_device_ptr(),
///         in_y.as_device_ptr(),
///         out_1.as_device_ptr(),
///         out_1.len()
///     ));
///     // `launch!` returns an error in case anything went wrong with the launch itself, but
///     // kernel launches are asynchronous so errors caused by the kernel (eg. invalid memory
///     // access) will show up later at some other CUDA API call (probably at `synchronize()`
///     // below).
///     result?;
///
///     // Launch the kernel again using the `function` form:
///     let function_name = CString::new("sum")?;
///     let sum = module.get_function(&function_name)?;
///     // Launch with 1x1x1 (1) blocks of 10x1x1 (10) threads, to show that you can use tuples to
///     // configure grid and block size.
///     let result = launch!(sum<<<(1, 1, 1), (10, 1, 1), 0, stream>>>(
///         in_x.as_device_ptr(),
///         in_y.as_device_ptr(),
///         out_2.as_device_ptr(),
///         out_2.len()
///     ));
///     result?;
/// }
///
/// // Kernel launches are asynchronous, so we wait for the kernels to finish executing.
/// stream.synchronize()?;
///
/// // Copy the results back to host memory
/// let mut out_host = [0.0f32; 20];
/// out_1.copy_to(&mut out_host[0..10])?;
/// out_2.copy_to(&mut out_host[10..20])?;
///
/// for x in out_host.iter() {
///     assert_eq!(3.0, *x);
/// }
/// # Ok(())
/// # }
/// ```
///
#[macro_export]
macro_rules! launch {
    // Form 1: look the kernel up in `$module` by name, then delegate to form 2.
    ($module:ident . $function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            let name = std::ffi::CString::new(stringify!($function)).unwrap();
            let function = $module.get_function(&name);
            match function {
                Ok(f) => launch!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
                Err(e) => Err(e),
            }
        }
    };
    // Form 2: launch an already-looked-up `Function` on `$stream`.
    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* )) => {
        {
            // Compile-time check that every argument implements `DeviceCopy`;
            // the `if false` block is type-checked but never executed.
            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
            if false {
                $(
                    assert_impl_devicecopy($arg);
                )*
            };

            // Each argument is passed to the driver as a pointer to its value, as
            // required by the kernel-parameter array convention.
            $stream.launch(&$function, $grid, $block, $shared,
                &[
                    $(
                        &$arg as *const _ as *mut ::std::ffi::c_void,
                    )*
                ]
            )
        }
    };
}
413
#[cfg(test)]
mod test {
    use super::*;
    use crate::memory::CopyDestination;
    use crate::memory::DeviceBuffer;
    use crate::quick_init;
    use crate::stream::{Stream, StreamFlags};
    use std::error::Error;
    use std::ffi::CString;

    /// Launches the `sum` kernel from the test PTX module and checks that it adds
    /// the two input buffers element-wise.
    #[test]
    fn test_launch() -> Result<(), Box<dyn Error>> {
        // Fix: propagate initialization failure with `?` instead of binding the
        // un-inspected `CudaResult`, which silently ignored errors and let the
        // test continue without a live CUDA context.
        let _context = quick_init()?;
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;

        unsafe {
            let mut in_x = DeviceBuffer::from_slice(&[2.0f32; 128])?;
            let mut in_y = DeviceBuffer::from_slice(&[1.0f32; 128])?;
            let mut out: DeviceBuffer<f32> = DeviceBuffer::uninitialized(128)?;

            let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
            // One block of 128 threads, no dynamic shared memory.
            launch!(module.sum<<<1, 128, 0, stream>>>(in_x.as_device_ptr(), in_y.as_device_ptr(), out.as_device_ptr(), out.len()))?;
            stream.synchronize()?;

            let mut out_host = [0f32; 128];
            out.copy_to(&mut out_host[..])?;
            for x in out_host.iter() {
                // 2.0 + 1.0 == 3.0 is exact in f32, so compare floats directly
                // instead of truncation-casting to u32.
                assert_eq!(3.0, *x);
            }
        }
        Ok(())
    }
}
447}