1use std::ffi::{c_char, c_void};
2
/// Status value returned by the `cuda_*` FFI functions declared in this file.
///
/// The struct is `#[repr(C)]` so it crosses the FFI boundary by value. It only carries a raw
/// pointer: copying it copies the pointer, and the derived `PartialEq`/`Eq` compare pointer
/// addresses, not the text the pointer refers to.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CudaRustError {
    /// Raw C-string pointer describing the status.
    /// NOTE(review): ownership/lifetime of the pointee is decided by the native side — confirm.
    pub message: *const c_char,
}
8
// Raw FFI declarations for status constants and the synchronous memory-management entry points
// implemented on the native (C/C++) side. The descriptions below are inferred from the names
// and signatures; confirm them against the native implementation.
extern "C" {
    /// Status value exported by the native side; by its name, presumably the success sentinel.
    pub static CUDA_SUCCESS_CSL: CudaRustError;

    /// Status value exported by the native side; by its name, presumably an out-of-memory error.
    pub static CUDA_OUT_OF_MEMORY: CudaRustError;

    /// Status value exported by the native side; by its name, presumably a "not ready" status
    /// (e.g. from a non-blocking query).
    pub static CUDA_ERROR_NOT_READY_SLOP: CudaRustError;

    /// Allocates `count` bytes; `ptr` is an out-parameter that presumably receives the
    /// allocation (device memory, judging by the name).
    pub fn cuda_malloc(ptr: *mut *mut c_void, count: usize) -> CudaRustError;

    /// Frees the allocation at `ptr` (presumably one obtained from `cuda_malloc`).
    pub fn cuda_free(ptr: *const c_void) -> CudaRustError;

    /// Writes free and total memory sizes through the `free` and `total` out-parameters.
    pub fn cuda_mem_get_info(free: *mut usize, total: *mut usize) -> CudaRustError;

    /// Allocates `count` bytes of host memory; the pointer is written through `ptr`.
    /// NOTE(review): presumably page-locked (pinned) memory — confirm.
    pub fn cuda_malloc_host(ptr: *mut *mut c_void, count: usize) -> CudaRustError;
    /// Registers `count` bytes of existing host memory at `ptr` with the native runtime.
    pub fn cuda_host_register(ptr: *const c_void, count: usize) -> CudaRustError;
    /// Frees host memory presumably obtained from `cuda_malloc_host`.
    pub fn cuda_free_host(ptr: *const c_void) -> CudaRustError;
    /// Undoes a previous `cuda_host_register` on `ptr`.
    pub fn cuda_host_unregister(ptr: *const c_void) -> CudaRustError;

    /// Sets `size` bytes starting at `dst` to `value`.
    pub fn cuda_mem_set(dst: *mut c_void, value: u8, size: usize) -> CudaRustError;

    /// Copies `count` bytes from host memory at `src` to device memory at `dst`.
    pub fn cuda_mem_copy_host_to_device(
        dst: *mut c_void,
        src: *const c_void,
        count: usize,
    ) -> CudaRustError;

    /// Copies `count` bytes from device memory at `src` to host memory at `dst`.
    pub fn cuda_mem_copy_device_to_host(
        dst: *mut c_void,
        src: *const c_void,
        count: usize,
    ) -> CudaRustError;

    /// Copies `count` bytes from device memory at `src` to device memory at `dst`.
    ///
    /// NOTE(review): `dst` is declared `*const c_void` here, while the other copy/set functions
    /// (including `cuda_mem_copy_device_to_device_async`) take `*mut c_void` for the destination.
    /// The ABI is identical, but the signature hides the write. It is left as-is because
    /// switching to `*mut` would break existing callers that pass a `*const` pointer.
    pub fn cuda_mem_copy_device_to_device(
        dst: *const c_void,
        src: *const c_void,
        count: usize,
    ) -> CudaRustError;
}
47
/// Opaque stream handle, as produced by `cuda_stream_create` and passed to the stream-ordered
/// FFI functions.
///
/// `#[repr(transparent)]` gives it exactly the ABI of the wrapped `*mut c_void`. Equality and
/// hashing compare the raw pointer value.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct CudaStreamHandle(
    /// Raw native stream pointer.
    pub *mut c_void,
);
51
/// Opaque event handle, as produced by `cuda_event_create` and passed to the event FFI
/// functions.
///
/// `#[repr(transparent)]` gives it exactly the ABI of the wrapped `*mut c_void`. Equality and
/// hashing compare the raw pointer value.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct CudaEventHandle(
    /// Raw native event pointer.
    pub *mut c_void,
);
55
/// Three-component extent (`x`, `y`, `z`) used as grid or block size for `cuda_launch_kernel`.
///
/// `#[repr(C)]` fixes the field order and layout so the struct can be passed by value across
/// the FFI boundary.
#[repr(C)]
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}
63
/// Type-erased kernel pointer passed by value to `cuda_launch_kernel`.
///
/// `#[repr(transparent)]` keeps the ABI identical to the wrapped `*const c_void`.
#[repr(transparent)]
pub struct KernelPtr(
    /// Raw kernel pointer.
    /// NOTE(review): presumably the host-side address of a kernel function — confirm.
    pub *const c_void,
);
66
/// Opaque memory-pool handle, filled in by `cuda_device_get_default_mem_pool` and
/// `cuda_device_get_mem_pool`.
///
/// `#[repr(transparent)]` keeps the ABI identical to the wrapped `*mut c_void`.
#[repr(transparent)]
pub struct CudaMemPool(
    /// Raw native memory-pool pointer.
    pub *mut c_void,
);
69
/// Device identifier passed by value to the memory-pool FFI functions.
///
/// `#[repr(transparent)]` keeps the ABI identical to the wrapped `i32`.
#[repr(transparent)]
pub struct CudaDevice(
    /// Device ordinal. NOTE(review): presumably the CUDA device index — confirm.
    pub i32,
);
72
73extern "C" {
74
75 pub static DEFAULT_STREAM: CudaStreamHandle;
76
77 pub fn cuda_device_synchronize() -> CudaRustError;
78 pub fn cuda_event_create(event: *mut CudaEventHandle) -> CudaRustError;
79 pub fn cuda_event_destroy(event: CudaEventHandle) -> CudaRustError;
80 pub fn cuda_event_record(event: CudaEventHandle, stream: CudaStreamHandle) -> CudaRustError;
81 pub fn cuda_event_synchronize(event: CudaEventHandle) -> CudaRustError;
82 pub fn cuda_event_elapsed_time(
83 ms: *mut f32,
84 start: CudaEventHandle,
85 end: CudaEventHandle,
86 ) -> CudaRustError;
87
88 pub fn cuda_stream_create(stream: *mut CudaStreamHandle) -> CudaRustError;
89 pub fn cuda_stream_destroy(stream: CudaStreamHandle) -> CudaRustError;
90 pub fn cuda_stream_synchronize(stream: CudaStreamHandle) -> CudaRustError;
91
92 pub fn cuda_stream_wait_event(
93 stream: CudaStreamHandle,
94 event: CudaEventHandle,
95 ) -> CudaRustError;
96
97 pub fn cuda_malloc_async(
100 devPtr: *mut *mut c_void,
101 size: usize,
102 stream: CudaStreamHandle,
103 ) -> CudaRustError;
104
105 pub fn cuda_mem_set_async(
106 dst: *mut c_void,
107 value: u8,
108 size: usize,
109 stream: CudaStreamHandle,
110 ) -> CudaRustError;
111
112 pub fn cuda_free_async(devPtr: *mut c_void, stream: CudaStreamHandle) -> CudaRustError;
113
114 pub fn cuda_mem_copy_device_to_device_async(
115 dst: *mut c_void,
116 src: *const c_void,
117 count: usize,
118 stream: CudaStreamHandle,
119 ) -> CudaRustError;
120 pub fn cuda_mem_copy_host_to_device_async(
121 dst: *mut c_void,
122 src: *const c_void,
123 count: usize,
124 stream: CudaStreamHandle,
125 ) -> CudaRustError;
126 pub fn cuda_mem_copy_device_to_host_async(
127 dst: *mut c_void,
128 src: *const c_void,
129 count: usize,
130 stream: CudaStreamHandle,
131 ) -> CudaRustError;
132 pub fn cuda_mem_copy_host_to_host_async(
133 dst: *mut c_void,
134 src: *const c_void,
135 count: usize,
136 stream: CudaStreamHandle,
137 ) -> CudaRustError;
138
139 pub fn cuda_stream_query(stream: CudaStreamHandle) -> CudaRustError;
140
141 pub fn cuda_event_query(event: CudaEventHandle) -> CudaRustError;
142
143 pub fn cuda_launch_host_function(
144 stream: CudaStreamHandle,
145 host_fn: Option<unsafe extern "C" fn(*mut c_void)>,
146 data: *const c_void,
147 ) -> CudaRustError;
148
149 pub fn cuda_launch_kernel(
150 kernel: KernelPtr,
151 grid: Dim3,
152 block: Dim3,
153 args: *mut *mut c_void,
154 shared_mem: usize,
155 stream: CudaStreamHandle,
156 ) -> CudaRustError;
157
158 pub fn cuda_device_get_default_mem_pool(
159 memPool: *mut CudaMemPool,
160 device: CudaDevice,
161 ) -> CudaRustError;
162
163 pub fn cuda_device_get_mem_pool(memPool: *mut CudaMemPool, device: CudaDevice)
164 -> CudaRustError;
165 pub fn cuda_mem_pool_set_release_threshold(
166 memPool: CudaMemPool,
167 threshold: u64,
168 ) -> CudaRustError;
169}
170
/// Identifier returned by `nvtx_range_start` and handed back to `nvtx_range_end`.
///
/// The wrapped value is private, so code outside this module cannot fabricate ids.
/// `#[repr(transparent)]` keeps the ABI identical to `u64`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct NvtxRangeId(
    /// Raw range id as produced by the native side.
    u64,
);
174
175extern "C" {
176 pub fn nvtx_range_start(name: *const c_char) -> NvtxRangeId;
177
178 pub fn nvtx_range_end(domain: NvtxRangeId);
179}
180
181impl Dim3 {
182 pub fn new(x: u32, y: u32, z: u32) -> Self {
183 Self { x, y, z }
184 }
185
186 pub fn x(num_elements: u32) -> Self {
187 Self { x: num_elements, y: 1, z: 1 }
188 }
189}
190
191impl From<u32> for Dim3 {
192 fn from(x: u32) -> Self {
193 Self { x, y: 1, z: 1 }
194 }
195}
196
197impl From<u64> for Dim3 {
198 fn from(x: u64) -> Self {
199 Self { x: x as u32, y: 1, z: 1 }
200 }
201}
202
203impl From<i32> for Dim3 {
204 fn from(x: i32) -> Self {
205 Self { x: x as u32, y: 1, z: 1 }
206 }
207}
208
209impl From<i64> for Dim3 {
210 fn from(x: i64) -> Self {
211 Self { x: x as u32, y: 1, z: 1 }
212 }
213}
214
215impl From<usize> for Dim3 {
216 fn from(x: usize) -> Self {
217 Self { x: x as u32, y: 1, z: 1 }
218 }
219}
220
221impl From<(u32, u32, u32)> for Dim3 {
222 fn from((x, y, z): (u32, u32, u32)) -> Self {
223 Self { x, y, z }
224 }
225}
226
227impl From<(u64, u64, u64)> for Dim3 {
228 fn from((x, y, z): (u64, u64, u64)) -> Self {
229 Self { x: x as u32, y: y as u32, z: z as u32 }
230 }
231}
232
233impl From<(usize, usize, usize)> for Dim3 {
234 fn from((x, y, z): (usize, usize, usize)) -> Self {
235 Self { x: x as u32, y: y as u32, z: z as u32 }
236 }
237}