1pub(crate) mod utils;
4
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::ptr;
8
9use opencl3::command_queue::CommandQueue;
10use opencl3::context::Context;
11use opencl3::error_codes::ClError;
12use opencl3::kernel::ExecuteKernel;
13use opencl3::memory::CL_MEM_READ_WRITE;
14use opencl3::types::CL_BLOCKING;
15
16use log::debug;
17
18use crate::device::{DeviceUuid, PciId, Vendor};
19use crate::error::{GPUError, GPUResult};
20use crate::LocalBuffer;
21
22#[allow(non_camel_case_types)]
24pub type cl_device_id = opencl3::types::cl_device_id;
25
26#[derive(Debug)]
28pub struct Buffer<T> {
29 buffer: opencl3::memory::Buffer<u8>,
30 length: usize,
32 _phantom: std::marker::PhantomData<T>,
33}
34
35#[derive(Debug, Clone)]
37pub struct Device {
38 vendor: Vendor,
39 name: String,
40 memory: u64,
42 compute_units: u32,
44 compute_capability: Option<(u32, u32)>,
46 pci_id: PciId,
47 uuid: Option<DeviceUuid>,
48 device: opencl3::device::Device,
49}
50
51impl Hash for Device {
52 fn hash<H: Hasher>(&self, state: &mut H) {
53 self.vendor.hash(state);
54 self.name.hash(state);
55 self.memory.hash(state);
56 self.pci_id.hash(state);
57 self.uuid.hash(state);
58 }
59}
60
61impl PartialEq for Device {
62 fn eq(&self, other: &Self) -> bool {
63 self.vendor == other.vendor
64 && self.name == other.name
65 && self.memory == other.memory
66 && self.pci_id == other.pci_id
67 && self.uuid == other.uuid
68 }
69}
70
71impl Eq for Device {}
72
73impl Device {
74 pub fn vendor(&self) -> Vendor {
76 self.vendor
77 }
78
79 pub fn name(&self) -> String {
81 self.name.clone()
82 }
83
84 pub fn memory(&self) -> u64 {
86 self.memory
87 }
88
89 pub fn compute_units(&self) -> u32 {
91 self.compute_units
92 }
93
94 pub fn compute_capability(&self) -> Option<(u32, u32)> {
97 self.compute_capability
98 }
99
100 pub fn pci_id(&self) -> PciId {
102 self.pci_id
103 }
104
105 pub fn uuid(&self) -> Option<DeviceUuid> {
108 self.uuid
109 }
110
111 pub fn cl_device_id(&self) -> cl_device_id {
116 self.device.id()
117 }
118}
119
120#[allow(rustdoc::broken_intra_doc_links)]
125pub struct Program {
126 device_name: String,
127 queue: CommandQueue,
128 context: Context,
129 kernels_by_name: HashMap<String, opencl3::kernel::Kernel>,
130}
131
132impl Program {
133 pub fn device_name(&self) -> &str {
135 &self.device_name
136 }
137
138 pub fn from_opencl(device: &Device, src: &str) -> GPUResult<Program> {
140 debug!("Creating OpenCL program from source.");
141 let cached = utils::cache_path(device, src)?;
142 if std::path::Path::exists(&cached) {
143 let bin = std::fs::read(cached)?;
144 Program::from_binary(device, bin)
145 } else {
146 let context = Context::from_device(&device.device)?;
147 debug!(
148 "Building kernel ({}) from sourceā¦",
149 cached.to_string_lossy()
150 );
151 let mut program = opencl3::program::Program::create_from_source(&context, src)?;
152 if let Err(build_error) = program.build(context.devices(), "") {
153 let log = program.get_build_log(context.devices()[0])?;
154 return Err(GPUError::Opencl3(build_error, Some(log)));
155 }
156 debug!(
157 "Building kernel ({}) from source: done.",
158 cached.to_string_lossy()
159 );
160 let queue = CommandQueue::create_default(&context, 0)?;
161 let kernels = opencl3::kernel::create_program_kernels(&program)?;
162 let kernels_by_name = kernels
163 .into_iter()
164 .map(|kernel| {
165 let name = kernel.function_name()?;
166 Ok((name, kernel))
167 })
168 .collect::<Result<_, ClError>>()?;
169 let prog = Program {
170 device_name: device.name(),
171 queue,
172 context,
173 kernels_by_name,
174 };
175 let binaries = program
176 .get_binaries()
177 .map_err(GPUError::ProgramInfoNotAvailable)?;
178 std::fs::write(cached, binaries[0].clone())?;
179 Ok(prog)
180 }
181 }
182
183 pub fn from_binary(device: &Device, bin: Vec<u8>) -> GPUResult<Program> {
185 debug!("Creating OpenCL program from binary.");
186 let context = Context::from_device(&device.device)?;
187 let bins = vec![&bin[..]];
188 let mut program = unsafe {
189 opencl3::program::Program::create_from_binary(&context, context.devices(), &bins)
190 }?;
191 if let Err(build_error) = program.build(context.devices(), "") {
192 let log = program.get_build_log(context.devices()[0])?;
193 return Err(GPUError::Opencl3(build_error, Some(log)));
194 }
195 let queue = CommandQueue::create_default(&context, 0)?;
196 let kernels = opencl3::kernel::create_program_kernels(&program)?;
197 let kernels_by_name = kernels
198 .into_iter()
199 .map(|kernel| {
200 let name = kernel.function_name()?;
201 Ok((name, kernel))
202 })
203 .collect::<Result<_, ClError>>()?;
204 Ok(Program {
205 device_name: device.name(),
206 queue,
207 context,
208 kernels_by_name,
209 })
210 }
211
212 pub unsafe fn create_buffer<T>(&self, length: usize) -> GPUResult<Buffer<T>> {
225 assert!(length > 0);
226 let mut buff = opencl3::memory::Buffer::create(
227 &self.context,
228 CL_MEM_READ_WRITE,
229 length * std::mem::size_of::<T>(),
232 ptr::null_mut(),
233 )?;
234
235 self.queue
237 .enqueue_write_buffer(&mut buff, opencl3::types::CL_BLOCKING, 0, &[0u8], &[])?;
238
239 Ok(Buffer::<T> {
240 buffer: buff,
241 length,
242 _phantom: std::marker::PhantomData,
243 })
244 }
245
246 pub fn create_buffer_from_slice<T>(&self, slice: &[T]) -> GPUResult<Buffer<T>> {
248 let length = slice.len();
249 let bytes_len = length * std::mem::size_of::<T>();
251
252 let mut buffer = unsafe {
253 opencl3::memory::Buffer::create(
254 &self.context,
255 CL_MEM_READ_WRITE,
256 bytes_len,
257 ptr::null_mut(),
258 )?
259 };
260 let bytes = unsafe {
262 std::slice::from_raw_parts(slice.as_ptr() as *const T as *const u8, bytes_len)
263 };
264 unsafe {
266 self.queue
267 .enqueue_write_buffer(&mut buffer, CL_BLOCKING, 0, &[0u8], &[])?;
268 self.queue
269 .enqueue_write_buffer(&mut buffer, CL_BLOCKING, 0, bytes, &[])?;
270 };
271
272 Ok(Buffer::<T> {
273 buffer,
274 length,
275 _phantom: std::marker::PhantomData,
276 })
277 }
278
279 pub fn create_kernel(
286 &self,
287 name: &str,
288 global_work_size: usize,
289 local_work_size: usize,
290 ) -> GPUResult<Kernel> {
291 let kernel = self
292 .kernels_by_name
293 .get(name)
294 .ok_or_else(|| GPUError::KernelNotFound(name.to_string()))?;
295 let mut builder = ExecuteKernel::new(kernel);
296 builder.set_global_work_size(global_work_size * local_work_size);
297 builder.set_local_work_size(local_work_size);
298 Ok(Kernel {
299 builder,
300 queue: &self.queue,
301 num_local_buffers: 0,
302 })
303 }
304
305 pub fn write_from_buffer<T>(
307 &self,
308 buffer: &mut Buffer<T>,
311 data: &[T],
312 ) -> GPUResult<()> {
313 assert!(data.len() <= buffer.length, "Buffer is too small");
314
315 let bytes = unsafe {
317 std::slice::from_raw_parts(
318 data.as_ptr() as *const T as *const u8,
319 data.len() * std::mem::size_of::<T>(),
320 )
321 };
322 unsafe {
323 self.queue
324 .enqueue_write_buffer(&mut buffer.buffer, CL_BLOCKING, 0, bytes, &[])?;
325 }
326 Ok(())
327 }
328
329 pub fn read_into_buffer<T>(&self, buffer: &Buffer<T>, data: &mut [T]) -> GPUResult<()> {
331 assert!(data.len() <= buffer.length, "Buffer is too small");
332
333 let bytes = unsafe {
335 std::slice::from_raw_parts_mut(
336 data.as_mut_ptr() as *mut T as *mut u8,
337 data.len() * std::mem::size_of::<T>(),
338 )
339 };
340 unsafe {
341 self.queue
342 .enqueue_read_buffer(&buffer.buffer, CL_BLOCKING, 0, bytes, &[])?;
343 };
344 Ok(())
345 }
346
347 pub fn run<F, R, E, A>(&self, fun: F, arg: A) -> Result<R, E>
352 where
353 F: FnOnce(&Self, A) -> Result<R, E>,
354 E: From<GPUError>,
355 {
356 fun(self, arg)
357 }
358}
359
360pub trait KernelArgument {
366 fn push(&self, kernel: &mut Kernel);
368}
369
370impl<T> KernelArgument for Buffer<T> {
371 fn push(&self, kernel: &mut Kernel) {
372 unsafe {
373 kernel.builder.set_arg(&self.buffer);
374 }
375 }
376}
377
378impl KernelArgument for i32 {
379 fn push(&self, kernel: &mut Kernel) {
380 unsafe {
381 kernel.builder.set_arg(self);
382 }
383 }
384}
385
386impl KernelArgument for u32 {
387 fn push(&self, kernel: &mut Kernel) {
388 unsafe {
389 kernel.builder.set_arg(self);
390 }
391 }
392}
393
394impl<T> KernelArgument for LocalBuffer<T> {
395 fn push(&self, kernel: &mut Kernel) {
396 unsafe {
397 kernel
398 .builder
399 .set_arg_local_buffer(self.length * std::mem::size_of::<T>());
400 }
401 kernel.num_local_buffers += 1;
402 }
403}
404
405#[derive(Debug)]
407pub struct Kernel<'a> {
408 pub builder: ExecuteKernel<'a>,
410 queue: &'a CommandQueue,
411 num_local_buffers: u8,
414}
415
416impl<'a> Kernel<'a> {
417 pub fn arg<T: KernelArgument>(mut self, t: &'a T) -> Self {
438 t.push(&mut self);
439 self
440 }
441
442 pub fn run(mut self) -> GPUResult<()> {
444 if self.num_local_buffers > 1 {
445 return Err(GPUError::Generic(
446 "There cannot be more than one `LocalBuffer`.".to_string(),
447 ));
448 }
449 unsafe {
450 self.builder.enqueue_nd_range(self.queue)?;
451 }
452 Ok(())
453 }
454}