1use std::{ffi::CStr, ptr};
2
3use singe_cuda_sys::driver;
4
5use crate::{
6 context::Context,
7 error::{Error, Result},
8 try_ffi,
9 types::FunctionAttribute,
10};
11
12pub trait KernelHandle {
17 type RawHandle: Copy;
19
20 unsafe fn raw_name(
27 raw: Self::RawHandle,
28 name: *mut *const i8,
29 device_id: i32,
30 ) -> driver::CUresult;
31
32 unsafe fn raw_attribute(
39 value: *mut i32,
40 attribute: driver::CUfunction_attribute,
41 raw: Self::RawHandle,
42 device_id: i32,
43 ) -> driver::CUresult;
44
45 unsafe fn set_attribute(
52 raw: Self::RawHandle,
53 attribute: driver::CUfunction_attribute,
54 value: i32,
55 device_id: i32,
56 ) -> driver::CUresult;
57}
58
59#[derive(Debug)]
60pub struct ModuleKernelHandle;
61
62impl KernelHandle for ModuleKernelHandle {
63 type RawHandle = driver::CUfunction;
64
65 unsafe fn raw_name(
66 raw: Self::RawHandle,
67 name: *mut *const i8,
68 _device_id: i32,
69 ) -> driver::CUresult {
70 unsafe { driver::cuFuncGetName(name, raw) }
71 }
72
73 unsafe fn raw_attribute(
74 value: *mut i32,
75 attribute: driver::CUfunction_attribute,
76 raw: Self::RawHandle,
77 _device_id: i32,
78 ) -> driver::CUresult {
79 unsafe { driver::cuFuncGetAttribute(value, attribute, raw) }
80 }
81
82 unsafe fn set_attribute(
83 raw: Self::RawHandle,
84 attribute: driver::CUfunction_attribute,
85 value: i32,
86 _device_id: i32,
87 ) -> driver::CUresult {
88 unsafe { driver::cuFuncSetAttribute(raw, attribute, value) }
89 }
90}
91
92#[derive(Debug)]
93pub struct LibraryKernelHandle;
94
95impl KernelHandle for LibraryKernelHandle {
96 type RawHandle = driver::CUkernel;
97
98 unsafe fn raw_name(
99 raw: Self::RawHandle,
100 name: *mut *const i8,
101 _device_id: i32,
102 ) -> driver::CUresult {
103 unsafe { driver::cuKernelGetName(name, raw) }
104 }
105
106 unsafe fn raw_attribute(
107 value: *mut i32,
108 attribute: driver::CUfunction_attribute,
109 raw: Self::RawHandle,
110 device_id: i32,
111 ) -> driver::CUresult {
112 unsafe { driver::cuKernelGetAttribute(value, attribute, raw, device_id) }
113 }
114
115 unsafe fn set_attribute(
116 raw: Self::RawHandle,
117 attribute: driver::CUfunction_attribute,
118 value: i32,
119 device_id: i32,
120 ) -> driver::CUresult {
121 unsafe { driver::cuKernelSetAttribute(attribute, value, raw, device_id) }
122 }
123}
124
125pub fn name<H: KernelHandle>(ctx: &Context, raw: H::RawHandle) -> Result<String> {
126 ctx.bind()?;
127 let mut name = ptr::null();
128 unsafe {
129 try_ffi!(H::raw_name(raw, &raw mut name, ctx.device().id()))?;
130 if name.is_null() {
131 return Err(Error::NullHandle);
132 }
133 Ok(CStr::from_ptr(name).to_string_lossy().into_owned())
134 }
135}
136
137pub fn attribute<H: KernelHandle>(
138 ctx: &Context,
139 raw: H::RawHandle,
140 attribute: FunctionAttribute,
141) -> Result<i32> {
142 ctx.bind()?;
143 let mut value = 0;
144 unsafe {
145 try_ffi!(H::raw_attribute(
146 &raw mut value,
147 attribute.into(),
148 raw,
149 ctx.device().id(),
150 ))?;
151 }
152 Ok(value)
153}
154
155pub fn set_attribute<H: KernelHandle>(
156 ctx: &Context,
157 raw: H::RawHandle,
158 attribute: FunctionAttribute,
159 value: i32,
160) -> Result<()> {
161 ctx.bind()?;
162 unsafe {
163 try_ffi!(H::set_attribute(
164 raw,
165 attribute.into(),
166 value,
167 ctx.device().id(),
168 ))?;
169 }
170 Ok(())
171}