1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::{Arc, OnceLock};
3
4use smallvec::{SmallVec, smallvec};
5use svod_dtype::DType;
6
7use snafu::ResultExt;
8use svod_dtype::ext::HasDType;
9
10use crate::allocator::{Allocator, BufferOptions, RawBuffer};
11use crate::error::{
12 InvalidViewSnafu, NdarrayShapeSnafu, NotCpuAccessibleSnafu, Result, SizeMismatchSnafu, TypeMismatchSnafu,
13};
14
/// Process-wide counter from which buffer and storage ids are drawn.
static BUFFER_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hands out the next id: the counter value *before* the increment, starting at 0.
///
/// `Relaxed` is enough here: uniqueness only needs the increment itself to be atomic,
/// and nothing in this file relies on the id to order other memory accesses.
fn next_buffer_id() -> u64 {
    AtomicU64::fetch_add(&BUFFER_ID_COUNTER, 1, Ordering::Relaxed)
}
24
/// Opaque numeric identifier for a [`Buffer`] handle or for its shared storage.
///
/// Ids minted in this module come from a single process-wide counter, so two ids
/// created here never compare equal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BufferId(pub u64);
34
#[cfg(feature = "cuda")]
use crate::error::CudaSnafu;
// `ResultExt` is already imported unconditionally at the top of the file; a second
// *named* import of it is a duplicate-definition error (E0252) once `cuda` is enabled,
// so this one is anonymous.
#[cfg(feature = "cuda")]
use snafu::ResultExt as _;
39
/// Backing storage shared (through `Arc`) by a [`Buffer`] and every view/clone of it.
///
/// The memory itself is lazy: `raw` stays empty until `ensure_allocated` runs, and
/// whatever was allocated is handed back to `allocator` when this value is dropped.
// NOTE: field order defines both the derived `Debug` output and drop order.
#[derive(Debug)]
struct BufferData {
    /// Id of this storage, drawn from the same counter as `Buffer` ids.
    storage_id: BufferId,
    /// The allocation; set at most once by `ensure_allocated`, taken in `Drop`.
    raw: OnceLock<RawBuffer>,
    /// Allocator that produces `raw` and is used again to free it.
    allocator: Arc<dyn Allocator>,
    /// Number of bytes requested from `allocator`.
    total_size: usize,
    /// Options passed to the allocator on both `alloc` and `free`.
    options: BufferOptions,
}
58
59impl BufferData {
60 fn new(allocator: Arc<dyn Allocator>, size: usize, options: BufferOptions) -> Self {
61 Self { storage_id: BufferId(next_buffer_id()), raw: OnceLock::new(), allocator, total_size: size, options }
62 }
63
64 fn ensure_allocated(&self) -> Result<()> {
67 if self.raw.get().is_some() {
68 return Ok(());
69 }
70
71 let raw = self.allocator.alloc(self.total_size, &self.options)?;
73
74 if let Err(raw) = self.raw.set(raw) {
76 self.allocator.free(raw, &self.options);
78 }
79
80 Ok(())
81 }
82
83 fn is_allocated(&self) -> bool {
85 self.raw.get().is_some()
86 }
87
88 fn raw(&self) -> &RawBuffer {
90 self.raw.get().expect("buffer not allocated")
91 }
92}
93
94impl Drop for BufferData {
95 fn drop(&mut self) {
96 if let Some(raw) = self.raw.take() {
98 self.allocator.free(raw, &self.options);
99 }
100 }
101}
102
/// A typed window (`offset`, `size`, `dtype`, `shape`) onto shared, lazily allocated storage.
///
/// `Clone` produces another handle with the *same* `id`; [`Buffer::view`] produces one
/// with a fresh `id`. Both share the underlying storage (same `storage_id`).
// NOTE: field order defines the derived `Debug` output.
#[derive(Debug, Clone)]
pub struct Buffer {
    /// Id of this handle.
    id: BufferId,
    /// Storage shared with every view/clone of this buffer.
    data: Arc<BufferData>,
    /// Byte offset of this buffer's first byte within the storage.
    offset: usize,
    /// Length in bytes of this buffer's range within the storage.
    size: usize,
    /// Element type.
    dtype: DType,
    /// Element dimensions; views created by `view` are 1-D.
    shape: SmallVec<[usize; 4]>,
}
125
126impl Buffer {
127 pub fn new(allocator: Arc<dyn Allocator>, dtype: DType, shape: Vec<usize>, options: BufferOptions) -> Self {
129 let size = dtype.bytes() * shape.iter().product::<usize>();
130 Self {
131 id: BufferId(next_buffer_id()),
132 data: Arc::new(BufferData::new(allocator, size, options)),
133 offset: 0,
134 size,
135 dtype,
136 shape: SmallVec::from_vec(shape),
137 }
138 }
139
140 pub fn allocate(
142 allocator: Arc<dyn Allocator>,
143 dtype: DType,
144 shape: Vec<usize>,
145 options: BufferOptions,
146 ) -> Result<Self> {
147 let buffer = Self::new(allocator, dtype, shape, options);
148 buffer.ensure_allocated()?;
149 Ok(buffer)
150 }
151
152 pub fn view(&self, offset: usize, size: usize) -> Result<Self> {
160 if offset + size > self.size {
162 return InvalidViewSnafu { offset, size, buffer_size: self.size }.fail();
163 }
164
165 Ok(Self {
166 id: BufferId(next_buffer_id()),
167 data: Arc::clone(&self.data),
168 offset: self.offset + offset,
169 size,
170 dtype: self.dtype.clone(),
171 shape: smallvec![size / self.dtype.bytes()],
173 })
174 }
175
176 pub fn ensure_allocated(&self) -> Result<()> {
178 self.data.ensure_allocated()
179 }
180
181 pub fn is_allocated(&self) -> bool {
183 self.data.is_allocated()
184 }
185
186 pub fn size(&self) -> usize {
188 self.size
189 }
190
191 pub fn offset(&self) -> usize {
193 self.offset
194 }
195
196 pub fn dtype(&self) -> DType {
198 self.dtype.clone()
199 }
200
201 pub fn shape(&self) -> &[usize] {
203 &self.shape
204 }
205
206 pub fn as_host_bytes(&self) -> Result<&[u8]> {
215 self.ensure_allocated()?;
216 let raw = self.data.raw();
217 match raw {
218 RawBuffer::Cpu { data, .. } => {
219 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
223 Ok(bytes)
224 }
225 RawBuffer::Mmap { data, .. } => Ok(&data[self.offset..self.offset + self.size]),
226 #[cfg(feature = "cuda")]
227 _ => NotCpuAccessibleSnafu.fail(),
228 }
229 }
230
231 #[allow(clippy::mut_from_ref)] pub fn as_host_bytes_mut(&self) -> Result<&mut [u8]> {
241 self.ensure_allocated()?;
242 let raw = self.data.raw();
243 match raw {
244 RawBuffer::Cpu { data, .. } => {
245 let bytes = unsafe { &mut (&mut *data.get())[self.offset..self.offset + self.size] };
249 Ok(bytes)
250 }
251 RawBuffer::Mmap { .. } => NotCpuAccessibleSnafu.fail(),
253 #[cfg(feature = "cuda")]
254 _ => NotCpuAccessibleSnafu.fail(),
255 }
256 }
257
258 pub fn as_array<T: HasDType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
268 self.ensure_allocated()?;
269 if self.dtype != T::DTYPE {
270 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
271 }
272 let raw = self.data.raw();
273 match raw {
274 RawBuffer::Cpu { data, .. } => {
275 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
276 let count = bytes.len() / T::DTYPE.bytes();
277 let typed = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) };
278 ndarray::ArrayViewD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
279 }
280 RawBuffer::Mmap { data, .. } => {
281 let bytes = &data[self.offset..self.offset + self.size];
282 let count = bytes.len() / T::DTYPE.bytes();
283 let typed = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) };
284 ndarray::ArrayViewD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
285 }
286 #[cfg(feature = "cuda")]
287 _ => NotCpuAccessibleSnafu.fail(),
288 }
289 }
290
291 #[allow(clippy::mut_from_ref)]
300 pub fn as_array_mut<T: HasDType>(&self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
301 self.ensure_allocated()?;
302 if self.dtype != T::DTYPE {
303 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
304 }
305 let raw = self.data.raw();
306 match raw {
307 RawBuffer::Cpu { data, cpu_accessible } if *cpu_accessible => {
308 let bytes = unsafe { &mut (&mut *data.get())[self.offset..self.offset + self.size] };
309 let count = bytes.len() / T::DTYPE.bytes();
310 let typed = unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut T, count) };
311 ndarray::ArrayViewMutD::from_shape(ndarray::IxDyn(&self.shape), typed).context(NdarrayShapeSnafu)
312 }
313 _ => NotCpuAccessibleSnafu.fail(),
314 }
315 }
316
317 pub fn as_slice<T: HasDType>(&self) -> Result<&[T]> {
319 self.ensure_allocated()?;
320 if self.dtype != T::DTYPE {
321 return TypeMismatchSnafu { expected: T::DTYPE, actual: self.dtype.clone() }.fail();
322 }
323 let raw = self.data.raw();
324 match raw {
325 RawBuffer::Cpu { data, cpu_accessible } if *cpu_accessible => {
326 let bytes = unsafe { &(&(*data.get()))[self.offset..self.offset + self.size] };
327 let count = bytes.len() / T::DTYPE.bytes();
328 Ok(unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const T, count) })
329 }
330 _ => NotCpuAccessibleSnafu.fail(),
331 }
332 }
333
334 pub fn item<T: HasDType + Copy>(&self) -> Result<T> {
338 let slice = self.as_slice::<T>()?;
339 assert_eq!(slice.len(), 1, "item() requires exactly 1 element, got {}", slice.len());
340 Ok(slice[0])
341 }
342
343 pub fn allocator(&self) -> &dyn Allocator {
345 &*self.data.allocator
346 }
347
348 pub fn allocator_arc(&self) -> Arc<dyn Allocator> {
352 Arc::clone(&self.data.allocator)
353 }
354
355 pub fn id(&self) -> BufferId {
363 self.id
364 }
365
366 pub fn total_size(&self) -> usize {
372 self.data.total_size
373 }
374
375 pub fn storage_id(&self) -> BufferId {
383 self.data.storage_id
384 }
385
386 pub fn copyin(&mut self, src: &[u8]) -> Result<()> {
388 self.ensure_allocated()?;
389
390 let expected = self.size;
391 let actual = src.len();
392 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
393
394 let raw = self.data.raw();
395 match raw {
396 RawBuffer::Cpu { data, .. } => {
397 let slice = unsafe {
399 let data_mut = &mut *data.get();
400 &mut data_mut[self.offset..self.offset + self.size]
401 };
402 slice.copy_from_slice(src);
403 Ok(())
404 }
405 RawBuffer::Mmap { .. } => panic!("DISK device is read-only: copyin not supported"),
406 #[cfg(feature = "cuda")]
407 RawBuffer::CudaDevice { data, device } => {
408 let cuda_data = unsafe { &mut *data.get() };
410 let mut view = cuda_data.slice_mut(self.offset..self.offset + self.size);
411 device.default_stream().memcpy_htod(src, &mut view).context(CudaSnafu)
412 }
413 #[cfg(feature = "cuda")]
414 RawBuffer::CudaUnified { data, .. } => {
415 let unified_data = unsafe { &mut *data.get() };
417 let slice = unified_data.as_mut_slice().context(CudaSnafu)?;
418 let target = &mut slice[self.offset..self.offset + self.size];
419 target.copy_from_slice(src);
420 Ok(())
421 }
422 }
423 }
424
425 pub fn copyout(&self, dst: &mut [u8]) -> Result<()> {
427 self.ensure_allocated()?;
428
429 let expected = self.size;
430 let actual = dst.len();
431 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
432
433 let raw = self.data.raw();
434 match raw {
435 RawBuffer::Cpu { data, .. } => {
436 let data_ref = unsafe { &*data.get() };
438 dst.copy_from_slice(&data_ref[self.offset..self.offset + self.size]);
439 Ok(())
440 }
441 RawBuffer::Mmap { data, .. } => {
442 dst.copy_from_slice(&data[self.offset..self.offset + self.size]);
443 Ok(())
444 }
445 #[cfg(feature = "cuda")]
446 RawBuffer::CudaDevice { data, device } => {
447 device.synchronize().context(CudaSnafu)?;
448 let cuda_data = unsafe { &*data.get() };
450 let view = cuda_data.slice(self.offset..self.offset + self.size);
451 device.default_stream().memcpy_dtoh(&view, dst).context(CudaSnafu)
452 }
453 #[cfg(feature = "cuda")]
454 RawBuffer::CudaUnified { data, .. } => {
455 let unified_data = unsafe { &*data.get() };
457 let slice = unified_data.as_slice().context(CudaSnafu)?;
458 let source = &slice[self.offset..self.offset + self.size];
459 dst.copy_from_slice(source);
460 Ok(())
461 }
462 }
463 }
464
465 pub fn copy_from(&mut self, src: &Buffer) -> Result<()> {
467 self.ensure_allocated()?;
468 src.ensure_allocated()?;
469
470 let expected = self.size;
471 let actual = src.size;
472 snafu::ensure!(expected == actual, SizeMismatchSnafu { expected, actual });
473
474 let dst_raw = self.data.raw();
475 let src_raw = src.data.raw();
476
477 match (dst_raw, src_raw) {
480 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
482 let dst_mut = unsafe { &mut *dst_data.get() };
483 let src_ref = unsafe { &*src_data.get() };
484 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
485 let src_slice = &src_ref[src.offset..src.offset + src.size];
486 dst_slice.copy_from_slice(src_slice);
487 Ok(())
488 }
489 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::Mmap { data: src_data, .. }) => {
491 let dst_mut = unsafe { &mut *dst_data.get() };
492 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
493 let src_slice = &src_data[src.offset..src.offset + src.size];
494 dst_slice.copy_from_slice(src_slice);
495 Ok(())
496 }
497 (RawBuffer::Mmap { .. }, _) => panic!("DISK device is read-only: copy_from not supported"),
499 #[cfg(feature = "cuda")]
501 (
502 RawBuffer::CudaDevice { data: dst_data, device: dst_device },
503 RawBuffer::CudaDevice { data: src_data, .. },
504 ) => {
505 let dst_cuda = unsafe { &mut *dst_data.get() };
506 let src_cuda = unsafe { &*src_data.get() };
507 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
508 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
509 dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_view).context(CudaSnafu)
510 }
511 #[cfg(feature = "cuda")]
513 (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::Cpu { data: src_data, .. }) => {
514 let dst_cuda = unsafe { &mut *dst_data.get() };
515 let src_ref = unsafe { &*src_data.get() };
516 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
517 let src_slice = &src_ref[src.offset..src.offset + src.size];
518 device.default_stream().memcpy_htod(src_slice, &mut dst_view).context(CudaSnafu)
519 }
520 #[cfg(feature = "cuda")]
522 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaDevice { data: src_data, device }) => {
523 let dst_mut = unsafe { &mut *dst_data.get() };
524 let src_cuda = unsafe { &*src_data.get() };
525 let dst_slice = &mut dst_mut[self.offset..self.offset + self.size];
526 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
527 device.default_stream().memcpy_dtoh(&src_view, dst_slice).context(CudaSnafu)
528 }
529 #[cfg(feature = "cuda")]
531 (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
532 let dst_unified = unsafe { &mut *dst_data.get() };
533 let src_unified = unsafe { &*src_data.get() };
534 let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
535 let src_slice = src_unified.as_slice().context(CudaSnafu)?;
536 let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
537 let src_source = &src_slice[src.offset..src.offset + src.size];
538 dst_target.copy_from_slice(src_source);
539 Ok(())
540 }
541 #[cfg(feature = "cuda")]
543 (RawBuffer::CudaUnified { data: dst_data, .. }, RawBuffer::Cpu { data: src_data, .. }) => {
544 let dst_unified = unsafe { &mut *dst_data.get() };
545 let src_ref = unsafe { &*src_data.get() };
546 let dst_slice = dst_unified.as_mut_slice().context(CudaSnafu)?;
547 let dst_target = &mut dst_slice[self.offset..self.offset + self.size];
548 let src_source = &src_ref[src.offset..src.offset + src.size];
549 dst_target.copy_from_slice(src_source);
550 Ok(())
551 }
552 #[cfg(feature = "cuda")]
554 (RawBuffer::Cpu { data: dst_data, .. }, RawBuffer::CudaUnified { data: src_data, .. }) => {
555 let dst_mut = unsafe { &mut *dst_data.get() };
556 let src_unified = unsafe { &*src_data.get() };
557 let src_slice = src_unified.as_slice().context(CudaSnafu)?;
558 let dst_target = &mut dst_mut[self.offset..self.offset + self.size];
559 let src_source = &src_slice[src.offset..src.offset + src.size];
560 dst_target.copy_from_slice(src_source);
561 Ok(())
562 }
563 #[cfg(feature = "cuda")]
565 (
566 RawBuffer::CudaUnified { data: dst_data, device: dst_device },
567 RawBuffer::CudaDevice { data: src_data, .. },
568 ) => {
569 let src_cuda = unsafe { &*src_data.get() };
570 let src_view = src_cuda.slice(src.offset..src.offset + src.size);
571 let dst_unified = unsafe { &mut *dst_data.get() };
573 let mut dst_target = dst_unified.slice_mut(self.offset..self.offset + self.size);
574 dst_device.default_stream().memcpy_dtod(&src_view, &mut dst_target).context(CudaSnafu)
576 }
577 #[cfg(feature = "cuda")]
579 (RawBuffer::CudaDevice { data: dst_data, device }, RawBuffer::CudaUnified { data: src_data, .. }) => {
580 let dst_cuda = unsafe { &mut *dst_data.get() };
581 let mut dst_view = dst_cuda.slice_mut(self.offset..self.offset + self.size);
582 let src_unified = unsafe { &*src_data.get() };
584 let src_source = src_unified.slice(src.offset..src.offset + src.size);
585 device.default_stream().memcpy_htod(&src_source, &mut dst_view).context(CudaSnafu)
587 }
588 }
589 }
590
591 pub fn synchronize(&self) -> Result<()> {
593 self.data.allocator.synchronize()
594 }
595
596 pub unsafe fn as_raw_ptr(&self) -> *mut u8 {
610 let raw = self.data.raw();
611 match raw {
612 RawBuffer::Cpu { data, .. } => {
613 unsafe { (&mut *data.get()).as_mut_ptr().add(self.offset) }
616 }
617 RawBuffer::Mmap { data, .. } => {
618 unsafe { data.as_ptr().add(self.offset) as *mut u8 }
620 }
621 #[cfg(feature = "cuda")]
622 RawBuffer::CudaDevice { .. } | RawBuffer::CudaUnified { .. } => {
623 unimplemented!("CUDA buffer raw pointers not yet supported for kernel execution")
627 }
628 }
629 }
630
631 #[cfg(test)]
636 pub(crate) fn raw_data_ptr(&self) -> usize {
637 let raw = self.data.raw();
638 match raw {
639 RawBuffer::Cpu { data, .. } => {
640 unsafe { (*data.get()).as_ptr() as usize }
642 }
643 RawBuffer::Mmap { data, .. } => data.as_ptr() as usize,
644 #[cfg(feature = "cuda")]
645 RawBuffer::CudaDevice { data, .. } => {
646 unsafe { &*data.get() as *const _ as usize }
649 }
650 #[cfg(feature = "cuda")]
651 RawBuffer::CudaUnified { data, .. } => {
652 unsafe { &*data.get() as *const _ as usize }
655 }
656 }
657 }
658}