1use std::{
4 marker::PhantomData,
5 mem::size_of,
6 ops::{Bound, RangeBounds},
7 ptr::NonNull,
8};
9
10use crate::{
11 error::{Error, Result},
12 memory::DeviceMemory,
13 types::{Complex32, Complex64, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16},
14};
15
/// Marker trait for element types usable with the buffer and view abstractions in this module.
///
/// The bounds on `Self` mean every implementor is `Copy` and `'static`.
///
/// # Safety
///
/// Implementing this trait is a promise the compiler cannot check.
/// NOTE(review): the exact contract is not written down in this file — presumably the
/// type must be plain data whose bytes may be copied verbatim between host and device;
/// confirm and record the precise requirements here.
pub unsafe trait DeviceRepr
where
    Self: Copy + 'static,
{
}
24
25pub unsafe trait ZeroableDeviceRepr: DeviceRepr {}
31
/// Implements both [`DeviceRepr`] and [`ZeroableDeviceRepr`] for each listed type.
///
/// Accepts a comma-separated list of types with an optional trailing comma.
/// Every expansion is an `unsafe impl`, so whoever adds a type to an invocation
/// vouches for the safety contracts of both traits.
macro_rules! impl_device_repr {
    ( $( $element:ty ),* $(,)? ) => {
        $(
            unsafe impl DeviceRepr for $element {}
            unsafe impl ZeroableDeviceRepr for $element {}
        )*
    };
}
40
41impl_device_repr!(
42 bool, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, f16, bf16,
43 Complex32, Complex64, f8e4m3, f8e5m2, f8ue8m0, f6e2m3, f6e3m2, f4e2m1,
44);
45
46pub trait DeviceSlice<T: DeviceRepr> {
48 fn as_device_ptr(&self) -> *const T;
49
50 fn len(&self) -> usize;
51
52 fn is_empty(&self) -> bool {
53 self.len() == 0
54 }
55
56 fn byte_len(&self) -> Result<usize> {
57 self.len()
58 .checked_mul(size_of::<T>())
59 .ok_or(Error::InvalidMemoryAllocationRequest)
60 }
61}
62
63pub trait DeviceSliceMut<T: DeviceRepr>: DeviceSlice<T> {
65 fn as_device_mut_ptr(&mut self) -> *mut T;
66}
67
68pub trait DeviceBuffer<T: DeviceRepr>: DeviceSlice<T> {}
70
71impl<T, B> DeviceBuffer<T> for B
72where
73 T: DeviceRepr,
74 B: DeviceSlice<T> + ?Sized,
75{
76}
77
78pub trait DeviceBufferMut<T: DeviceRepr>: DeviceBuffer<T> + DeviceSliceMut<T> {}
80
81impl<T, B> DeviceBufferMut<T> for B
82where
83 T: DeviceRepr,
84 B: DeviceBuffer<T> + DeviceSliceMut<T> + ?Sized,
85{
86}
87
88pub trait HostSlice<T: DeviceRepr> {
90 fn as_host_ptr(&self) -> *const T;
91
92 fn len(&self) -> usize;
93
94 fn is_empty(&self) -> bool {
95 self.len() == 0
96 }
97}
98
99pub trait HostSliceMut<T: DeviceRepr>: HostSlice<T> {
101 fn as_host_mut_ptr(&mut self) -> *mut T;
102}
103
104pub trait HostBuffer<T: DeviceRepr>: HostSlice<T> {}
106
107impl<T, B> HostBuffer<T> for B
108where
109 T: DeviceRepr,
110 B: HostSlice<T> + ?Sized,
111{
112}
113
114pub trait HostBufferMut<T: DeviceRepr>: HostBuffer<T> + HostSliceMut<T> {}
116
117impl<T, B> HostBufferMut<T> for B
118where
119 T: DeviceRepr,
120 B: HostBuffer<T> + HostSliceMut<T> + ?Sized,
121{
122}
123
/// Untyped, read-only view of a buffer: a raw byte pointer plus a length in bytes.
pub trait ByteBuffer {
    /// Returns the raw pointer to the first byte.
    fn as_byte_ptr(&self) -> *const u8;

    /// Returns the length in bytes.
    fn byte_len(&self) -> usize;

    /// Returns `true` when [`byte_len`](Self::byte_len) is zero.
    fn is_empty(&self) -> bool {
        matches!(self.byte_len(), 0)
    }
}
134
135pub trait ByteBufferMut: ByteBuffer {
137 fn as_byte_mut_ptr(&mut self) -> *mut u8;
138}
139
140impl<B> ByteBuffer for B
141where
142 B: DeviceSlice<u8> + ?Sized,
143{
144 fn as_byte_ptr(&self) -> *const u8 {
145 self.as_device_ptr()
146 }
147
148 fn byte_len(&self) -> usize {
149 self.len()
150 }
151}
152
153impl<B> ByteBufferMut for B
154where
155 B: DeviceSliceMut<u8> + ?Sized,
156{
157 fn as_byte_mut_ptr(&mut self) -> *mut u8 {
158 self.as_device_mut_ptr()
159 }
160}
161
162#[derive(Debug, Clone, Copy)]
163pub struct DeviceView<'a, T: DeviceRepr> {
171 ptr: *const T,
172 length: usize,
173 _t: PhantomData<&'a T>,
174}
175
176#[derive(Debug)]
177pub struct DeviceViewMut<'a, T: DeviceRepr> {
182 ptr: *mut T,
183 length: usize,
184 _t: PhantomData<&'a mut T>,
185}
186
187impl<'a, T: DeviceRepr> DeviceView<'a, T> {
188 pub const unsafe fn from_raw_parts(ptr: *const T, length: usize) -> Self {
199 let ptr = if length == 0 {
200 NonNull::<T>::dangling().as_ptr() as *const T
201 } else {
202 ptr
203 };
204 Self {
205 ptr,
206 length,
207 _t: PhantomData,
208 }
209 }
210
211 pub fn from_memory(memory: &'a DeviceMemory<T>) -> Self {
212 Self {
213 ptr: memory.as_ptr(),
214 length: memory.len(),
215 _t: PhantomData,
216 }
217 }
218
219 pub const fn as_ptr(&self) -> *const T {
220 self.ptr
221 }
222
223 pub const fn len(&self) -> usize {
224 self.length
225 }
226
227 pub const fn is_empty(&self) -> bool {
228 self.length == 0
229 }
230
231 pub fn slice<R: RangeBounds<usize>>(self, range: R) -> Result<Self> {
232 let (start, end) = bounds_to_range(range, self.length)?;
233 let ptr = self.ptr.wrapping_add(start);
236 Ok(Self {
237 ptr,
238 length: end - start,
239 _t: PhantomData,
240 })
241 }
242}
243
244impl<'a, T: DeviceRepr> DeviceViewMut<'a, T> {
245 pub const unsafe fn from_raw_parts(ptr: *mut T, length: usize) -> Self {
256 let ptr = if length == 0 {
257 NonNull::<T>::dangling().as_ptr()
258 } else {
259 ptr
260 };
261 Self {
262 ptr,
263 length,
264 _t: PhantomData,
265 }
266 }
267
268 pub fn from_memory(memory: &'a mut DeviceMemory<T>) -> Self {
269 Self {
270 ptr: memory.as_mut_ptr(),
271 length: memory.len(),
272 _t: PhantomData,
273 }
274 }
275
276 pub const fn as_ptr(&self) -> *const T {
277 self.ptr
278 }
279
280 pub const fn as_mut_ptr(&mut self) -> *mut T {
281 self.ptr
282 }
283
284 pub const fn len(&self) -> usize {
285 self.length
286 }
287
288 pub const fn is_empty(&self) -> bool {
289 self.length == 0
290 }
291
292 pub fn as_view(&self) -> DeviceView<'_, T> {
293 DeviceView {
294 ptr: self.ptr,
295 length: self.length,
296 _t: PhantomData,
297 }
298 }
299
300 pub fn slice<R: RangeBounds<usize>>(&self, range: R) -> Result<DeviceView<'_, T>> {
301 self.as_view().slice(range)
302 }
303
304 pub fn slice_mut<R: RangeBounds<usize>>(&mut self, range: R) -> Result<DeviceViewMut<'_, T>> {
305 let (start, end) = bounds_to_range(range, self.length)?;
306 let ptr = self.ptr.wrapping_add(start);
309 Ok(DeviceViewMut {
310 ptr,
311 length: end - start,
312 _t: PhantomData,
313 })
314 }
315
316 pub fn split_at_mut(
317 &mut self,
318 mid: usize,
319 ) -> Result<(DeviceViewMut<'_, T>, DeviceViewMut<'_, T>)> {
320 if mid > self.length {
321 return Err(Error::InvalidMemoryAccess);
322 }
323
324 let right = self.ptr.wrapping_add(mid);
326 Ok((
327 DeviceViewMut {
328 ptr: self.ptr,
329 length: mid,
330 _t: PhantomData,
331 },
332 DeviceViewMut {
333 ptr: right,
334 length: self.length - mid,
335 _t: PhantomData,
336 },
337 ))
338 }
339}
340
341impl<T: DeviceRepr> DeviceMemory<T> {
342 pub fn view(&self) -> DeviceView<'_, T> {
343 DeviceView::from_memory(self)
344 }
345
346 pub fn view_mut(&mut self) -> DeviceViewMut<'_, T> {
347 DeviceViewMut::from_memory(self)
348 }
349}
350
351impl<T: DeviceRepr> DeviceSlice<T> for DeviceMemory<T> {
352 fn as_device_ptr(&self) -> *const T {
353 self.as_ptr()
354 }
355
356 fn len(&self) -> usize {
357 self.len()
358 }
359}
360
361impl<T: DeviceRepr> DeviceSliceMut<T> for DeviceMemory<T> {
362 fn as_device_mut_ptr(&mut self) -> *mut T {
363 self.as_mut_ptr()
364 }
365}
366
367impl<T: DeviceRepr> DeviceSlice<T> for DeviceView<'_, T> {
368 fn as_device_ptr(&self) -> *const T {
369 self.ptr
370 }
371
372 fn len(&self) -> usize {
373 self.length
374 }
375}
376
377impl<T: DeviceRepr> DeviceSlice<T> for DeviceViewMut<'_, T> {
378 fn as_device_ptr(&self) -> *const T {
379 self.ptr
380 }
381
382 fn len(&self) -> usize {
383 self.length
384 }
385}
386
387impl<T: DeviceRepr> DeviceSliceMut<T> for DeviceViewMut<'_, T> {
388 fn as_device_mut_ptr(&mut self) -> *mut T {
389 self.ptr
390 }
391}
392
393impl<T: DeviceRepr> HostSlice<T> for [T] {
394 fn as_host_ptr(&self) -> *const T {
395 self.as_ptr()
396 }
397
398 fn len(&self) -> usize {
399 self.len()
400 }
401}
402
403impl<T: DeviceRepr> HostSliceMut<T> for [T] {
404 fn as_host_mut_ptr(&mut self) -> *mut T {
405 self.as_mut_ptr()
406 }
407}
408
409impl<T: DeviceRepr, const N: usize> HostSlice<T> for [T; N] {
410 fn as_host_ptr(&self) -> *const T {
411 self.as_ptr()
412 }
413
414 fn len(&self) -> usize {
415 N
416 }
417}
418
419impl<T: DeviceRepr, const N: usize> HostSliceMut<T> for [T; N] {
420 fn as_host_mut_ptr(&mut self) -> *mut T {
421 self.as_mut_ptr()
422 }
423}
424
425impl<T: DeviceRepr> HostSlice<T> for Vec<T> {
426 fn as_host_ptr(&self) -> *const T {
427 self.as_ptr()
428 }
429
430 fn len(&self) -> usize {
431 self.len()
432 }
433}
434
435impl<T: DeviceRepr> HostSliceMut<T> for Vec<T> {
436 fn as_host_mut_ptr(&mut self) -> *mut T {
437 self.as_mut_ptr()
438 }
439}
440
441fn bounds_to_range<R: RangeBounds<usize>>(range: R, length: usize) -> Result<(usize, usize)> {
442 let start = match range.start_bound() {
443 Bound::Included(&value) => value,
444 Bound::Excluded(&value) => value.checked_add(1).ok_or(Error::InvalidMemoryAccess)?,
445 Bound::Unbounded => 0,
446 };
447 let end = match range.end_bound() {
448 Bound::Included(&value) => value.checked_add(1).ok_or(Error::InvalidMemoryAccess)?,
449 Bound::Excluded(&value) => value,
450 Bound::Unbounded => length,
451 };
452
453 if start > end || end > length {
454 return Err(Error::InvalidMemoryAccess);
455 }
456
457 Ok((start, end))
458}