1use crate::Device;
4use torsh_core::{
5 dtype::DType,
6 error::{Result, TorshError},
7 shape::Shape,
8};
9
10#[cfg(not(feature = "std"))]
11use alloc::{boxed::Box, string::String, vec::Vec};
12
13#[cfg(not(feature = "std"))]
14use core::sync::atomic::{AtomicUsize, Ordering};
15#[cfg(feature = "std")]
16use std::sync::atomic::{AtomicUsize, Ordering};
17
/// Backing counter for [`generate_buffer_id`]; the first id handed out is 1.
static BUFFER_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns a buffer id that no earlier call has returned (until the counter
/// wraps around after `usize::MAX` calls).
///
/// `Relaxed` ordering is sufficient here: the atomic read-modify-write alone
/// guarantees every caller receives a distinct value, and the id is not used
/// to order any other memory accesses.
pub fn generate_buffer_id() -> usize {
    AtomicUsize::fetch_add(&BUFFER_ID_COUNTER, 1, Ordering::Relaxed)
}
25
26#[derive(Debug, Clone)]
28pub struct Buffer {
29 pub id: usize,
31
32 pub device: Device,
34
35 pub size: usize,
37
38 pub usage: BufferUsage,
40
41 pub descriptor: BufferDescriptor,
43
44 pub handle: BufferHandle,
46}
47
48impl Buffer {
49 pub fn new(
51 id: usize,
52 device: Device,
53 size: usize,
54 usage: BufferUsage,
55 descriptor: BufferDescriptor,
56 handle: BufferHandle,
57 ) -> Self {
58 Self {
59 id,
60 device,
61 size,
62 usage,
63 descriptor,
64 handle,
65 }
66 }
67
68 pub fn id(&self) -> usize {
70 self.id
71 }
72
73 pub fn device(&self) -> &Device {
75 &self.device
76 }
77
78 pub fn size(&self) -> usize {
80 self.size
81 }
82
83 pub fn usage(&self) -> BufferUsage {
85 self.usage
86 }
87
88 pub fn handle(&self) -> &BufferHandle {
90 &self.handle
91 }
92
93 pub fn supports_usage(&self, usage: BufferUsage) -> bool {
95 self.usage.contains(usage)
96 }
97}
98
99#[derive(Debug, Clone, PartialEq)]
101pub struct BufferDescriptor {
102 pub size: usize,
104
105 pub usage: BufferUsage,
107
108 pub location: MemoryLocation,
110
111 pub dtype: Option<DType>,
113
114 pub shape: Option<Shape>,
116
117 pub initial_data: Option<Vec<u8>>,
119
120 pub alignment: Option<usize>,
122
123 pub zero_init: bool,
125}
126
127impl BufferDescriptor {
128 pub fn new(size: usize, usage: BufferUsage) -> Self {
130 Self {
131 size,
132 usage,
133 location: MemoryLocation::Device,
134 dtype: None,
135 shape: None,
136 initial_data: None,
137 alignment: None,
138 zero_init: false,
139 }
140 }
141
142 pub fn with_location(mut self, location: MemoryLocation) -> Self {
144 self.location = location;
145 self
146 }
147
148 pub fn with_dtype(mut self, dtype: DType) -> Self {
150 self.dtype = Some(dtype);
151 self
152 }
153
154 pub fn with_shape(mut self, shape: Shape) -> Self {
156 self.shape = Some(shape);
157 self
158 }
159
160 pub fn with_initial_data(mut self, data: Vec<u8>) -> Self {
162 self.initial_data = Some(data);
163 self
164 }
165
166 pub fn with_alignment(mut self, alignment: usize) -> Self {
168 self.alignment = Some(alignment);
169 self
170 }
171
172 pub fn with_zero_init(mut self) -> Self {
174 self.zero_init = true;
175 self
176 }
177}
178
/// Set of buffer usage flags stored as a `u32` bit mask.
///
/// Combine flags with `|` / `|=` (equivalently [`BufferUsage::union`]) and
/// query them with [`BufferUsage::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferUsage {
    bits: u32,
}

impl BufferUsage {
    /// Empty set: no flags.
    pub const NONE: Self = Self { bits: 0 };
    /// `READ` flag (bit 0).
    pub const READ: Self = Self { bits: 1 << 0 };
    /// `WRITE` flag (bit 1).
    pub const WRITE: Self = Self { bits: 1 << 1 };
    /// `STORAGE` flag (bit 2).
    pub const STORAGE: Self = Self { bits: 1 << 2 };
    /// `UNIFORM` flag (bit 3).
    pub const UNIFORM: Self = Self { bits: 1 << 3 };
    /// `VERTEX` flag (bit 4).
    pub const VERTEX: Self = Self { bits: 1 << 4 };
    /// `INDEX` flag (bit 5).
    pub const INDEX: Self = Self { bits: 1 << 5 };
    /// `COPY_SRC` flag (bit 6).
    pub const COPY_SRC: Self = Self { bits: 1 << 6 };
    /// `COPY_DST` flag (bit 7).
    pub const COPY_DST: Self = Self { bits: 1 << 7 };
    /// `MAP_READ` flag (bit 8).
    pub const MAP_READ: Self = Self { bits: 1 << 8 };
    /// `MAP_WRITE` flag (bit 9).
    pub const MAP_WRITE: Self = Self { bits: 1 << 9 };

    /// `READ | WRITE`.
    pub const READ_WRITE: Self = Self::READ.union(Self::WRITE);
    /// `STORAGE | READ | WRITE`.
    pub const STORAGE_READ_WRITE: Self = Self::STORAGE.union(Self::READ_WRITE);

    /// Builds a usage set from a raw bit mask; bits outside the named flags
    /// are kept as-is.
    pub const fn new(bits: u32) -> Self {
        Self { bits }
    }

    /// Returns `true` if every flag set in `other` is also set in `self`
    /// (so `contains(BufferUsage::NONE)` is always `true`).
    pub const fn contains(self, other: Self) -> bool {
        (self.bits & other.bits) == other.bits
    }

    /// Returns the flags set in `self` or in `other`.
    pub const fn union(self, other: Self) -> Self {
        Self {
            bits: self.bits | other.bits,
        }
    }

    /// Returns the flags set in `self` but not in `other`.
    pub const fn difference(self, other: Self) -> Self {
        Self {
            bits: self.bits & !other.bits,
        }
    }

    /// Returns the raw bit mask.
    pub const fn bits(self) -> u32 {
        self.bits
    }
}

// `core::ops` instead of `std::ops`: they are the same traits, but the `core`
// path also resolves when the crate is built without the `std` feature.
impl core::ops::BitOr for BufferUsage {
    type Output = Self;

    /// Same as [`BufferUsage::union`].
    fn bitor(self, rhs: Self) -> Self::Output {
        self.union(rhs)
    }
}

impl core::ops::BitOrAssign for BufferUsage {
    /// In-place [`BufferUsage::union`].
    fn bitor_assign(&mut self, rhs: Self) {
        *self = self.union(rhs);
    }
}
270
/// Requested placement of a buffer's memory.
///
/// `Device` is the default. The precise meaning of each placement is left to
/// the backend that honours the descriptor.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MemoryLocation {
    #[default]
    Device,
    Host,
    Unified,
    HostCached,
    DeviceHost,
}
290
/// Backend reference to the memory behind a buffer, together with its size.
///
/// Only `Cpu` and `Generic` are always present. The other variants are compiled
/// in by the `cuda` / `webgpu` features, or by `metal` on aarch64 macOS.
#[derive(Debug)]
pub enum BufferHandle {
    /// Host allocation addressed by a raw pointer.
    Cpu { ptr: *mut u8, size: usize },

    /// CUDA allocation identified by `device_ptr`.
    #[cfg(feature = "cuda")]
    Cuda { device_ptr: u64, size: usize },

    /// Metal buffer identified by `buffer_id`.
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { buffer_id: u64, size: usize },

    /// WebGPU buffer identified by `buffer_ptr`.
    #[cfg(feature = "webgpu")]
    WebGpu { buffer_ptr: u64, size: usize },

    /// Type-erased handle for backends without a dedicated variant.
    Generic {
        // `core::any::Any` is the same trait as `std::any::Any` but also
        // resolves when the crate is built without the `std` feature.
        handle: Box<dyn core::any::Any + Send + Sync>,
        size: usize,
    },
}

impl Clone for BufferHandle {
    /// Copies the handle itself; the memory it refers to is not duplicated.
    ///
    /// # Panics
    ///
    /// Panics for `BufferHandle::Generic`, because its boxed `dyn Any` cannot
    /// be cloned.
    fn clone(&self) -> Self {
        match *self {
            Self::Cpu { ptr, size } => Self::Cpu { ptr, size },
            #[cfg(feature = "cuda")]
            Self::Cuda { device_ptr, size } => Self::Cuda { device_ptr, size },
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            Self::Metal { buffer_id, size } => Self::Metal { buffer_id, size },
            #[cfg(feature = "webgpu")]
            Self::WebGpu { buffer_ptr, size } => Self::WebGpu { buffer_ptr, size },
            Self::Generic { .. } => panic!("Cannot clone Generic buffer handles"),
        }
    }
}

impl BufferHandle {
    /// Returns the `size` recorded in the handle.
    pub fn size(&self) -> usize {
        match *self {
            Self::Cpu { size, .. } => size,
            #[cfg(feature = "cuda")]
            Self::Cuda { size, .. } => size,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            Self::Metal { size, .. } => size,
            #[cfg(feature = "webgpu")]
            Self::WebGpu { size, .. } => size,
            Self::Generic { size, .. } => size,
        }
    }

    /// Returns a numeric identity derived from the backend reference.
    ///
    /// `Generic` handles always report `0`. The `u64` backend values are cast
    /// with `as` and are therefore truncated on targets where `usize` is
    /// narrower than 64 bits.
    pub fn id(&self) -> usize {
        match *self {
            Self::Cpu { ptr, .. } => ptr as usize,
            #[cfg(feature = "cuda")]
            Self::Cuda { device_ptr, .. } => device_ptr as usize,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            Self::Metal { buffer_id, .. } => buffer_id as usize,
            #[cfg(feature = "webgpu")]
            Self::WebGpu { buffer_ptr, .. } => buffer_ptr as usize,
            Self::Generic { .. } => 0,
        }
    }

    /// Returns `true` when the size is non-zero and, for every variant except
    /// `Generic`, the backend reference is non-null / non-zero.
    pub fn is_valid(&self) -> bool {
        match *self {
            Self::Cpu { ptr, size } => !ptr.is_null() && size > 0,
            #[cfg(feature = "cuda")]
            Self::Cuda { device_ptr, size } => device_ptr != 0 && size > 0,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            Self::Metal { buffer_id, size } => buffer_id != 0 && size > 0,
            #[cfg(feature = "webgpu")]
            Self::WebGpu { buffer_ptr, size } => buffer_ptr != 0 && size > 0,
            Self::Generic { size, .. } => size > 0,
        }
    }
}

// SAFETY: NOTE(review) — `Cpu` holds a raw `*mut u8`, which is neither `Send`
// nor `Sync` by itself. These impls assert that moving or sharing a handle
// across threads is sound. Nothing in this module dereferences the pointer, so
// soundness rests on whoever does dereference it synchronising access — confirm
// against the backends/allocators that create these handles.
unsafe impl Send for BufferHandle {}
unsafe impl Sync for BufferHandle {}

impl PartialEq for BufferHandle {
    /// Two handles are equal when they are the same variant with the same
    /// backend reference and the same size.
    ///
    /// NOTE(review): a boxed `dyn Any` cannot be compared, so two `Generic`
    /// handles are considered equal whenever their sizes match, even if they
    /// wrap different objects. Keep this in mind before using handles as map
    /// keys.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Cpu { ptr: a, size: sa }, Self::Cpu { ptr: b, size: sb }) => {
                a == b && sa == sb
            }
            #[cfg(feature = "cuda")]
            (
                Self::Cuda {
                    device_ptr: a,
                    size: sa,
                },
                Self::Cuda {
                    device_ptr: b,
                    size: sb,
                },
            ) => a == b && sa == sb,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            (
                Self::Metal {
                    buffer_id: a,
                    size: sa,
                },
                Self::Metal {
                    buffer_id: b,
                    size: sb,
                },
            ) => a == b && sa == sb,
            #[cfg(feature = "webgpu")]
            (
                Self::WebGpu {
                    buffer_ptr: a,
                    size: sa,
                },
                Self::WebGpu {
                    buffer_ptr: b,
                    size: sb,
                },
            ) => a == b && sa == sb,
            (Self::Generic { size: sa, .. }, Self::Generic { size: sb, .. }) => sa == sb,
            _ => false,
        }
    }
}

impl Eq for BufferHandle {}

impl core::hash::Hash for BufferHandle {
    /// Hashes a per-variant tag followed by exactly the fields `eq` compares,
    /// so equal handles always hash equally.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // Bring the trait's `hash` method into scope for the field calls below;
        // implementing a trait does not import it for method resolution.
        use core::hash::Hash as _;

        match self {
            Self::Cpu { ptr, size } => {
                0u8.hash(state);
                (*ptr as usize).hash(state);
                size.hash(state);
            }
            #[cfg(feature = "cuda")]
            Self::Cuda { device_ptr, size } => {
                1u8.hash(state);
                device_ptr.hash(state);
                size.hash(state);
            }
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            Self::Metal { buffer_id, size } => {
                2u8.hash(state);
                buffer_id.hash(state);
                size.hash(state);
            }
            #[cfg(feature = "webgpu")]
            Self::WebGpu { buffer_ptr, size } => {
                3u8.hash(state);
                buffer_ptr.hash(state);
                size.hash(state);
            }
            Self::Generic { size, .. } => {
                4u8.hash(state);
                size.hash(state);
            }
        }
    }
}
490
491#[derive(Debug)]
493pub struct BufferView {
494 pub buffer: Buffer,
496
497 pub offset: usize,
499
500 pub size: usize,
502
503 pub dtype: Option<DType>,
505
506 pub shape: Option<Shape>,
508}
509
510impl BufferView {
511 pub fn new(buffer: Buffer, offset: usize, size: usize) -> Result<Self> {
513 if offset + size > buffer.size {
514 return Err(TorshError::InvalidArgument(
515 "Buffer view exceeds buffer bounds".to_string(),
516 ));
517 }
518
519 Ok(Self {
520 buffer,
521 offset,
522 size,
523 dtype: None,
524 shape: None,
525 })
526 }
527
528 pub fn typed(mut self, dtype: DType) -> Self {
530 self.dtype = Some(dtype);
531 self
532 }
533
534 pub fn shaped(mut self, shape: Shape) -> Self {
536 self.shape = Some(shape);
537 self
538 }
539
540 pub fn buffer(&self) -> &Buffer {
542 &self.buffer
543 }
544
545 pub fn offset(&self) -> usize {
547 self.offset
548 }
549
550 pub fn size(&self) -> usize {
552 self.size
553 }
554
555 pub fn end_offset(&self) -> usize {
557 self.offset + self.size
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType, shape::Shape};

    /// Builds a CPU test device with default `DeviceInfo`.
    fn create_test_device() -> Device {
        let device_info = DeviceInfo::default();
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), device_info)
    }

    /// Wraps a fake CPU address (never dereferenced by these tests) in a
    /// `READ_WRITE` buffer of `size` on a fresh test device.
    fn make_cpu_buffer(id: usize, addr: usize, size: usize) -> Buffer {
        let device = create_test_device();
        let descriptor = BufferDescriptor::new(size, BufferUsage::READ_WRITE);
        let handle = BufferHandle::Cpu {
            ptr: addr as *mut u8,
            size,
        };
        Buffer::new(id, device, size, BufferUsage::READ_WRITE, descriptor, handle)
    }

    #[test]
    fn test_buffer_descriptor_creation() {
        let descriptor = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);

        assert_eq!(descriptor.size, 1024);
        assert_eq!(descriptor.usage, BufferUsage::READ_WRITE);
        assert_eq!(descriptor.location, MemoryLocation::Device);
        assert!(descriptor.dtype.is_none());
        assert!(descriptor.shape.is_none());
        assert!(descriptor.initial_data.is_none());
        assert!(descriptor.alignment.is_none());
        assert!(!descriptor.zero_init);
    }

    #[test]
    fn test_buffer_descriptor_builder() {
        let descriptor = BufferDescriptor::new(2048, BufferUsage::STORAGE)
            .with_location(MemoryLocation::Host)
            .with_dtype(DType::F32)
            .with_shape(Shape::new(vec![64, 32]))
            .with_alignment(64)
            .with_zero_init();

        assert_eq!(descriptor.size, 2048);
        assert_eq!(descriptor.usage, BufferUsage::STORAGE);
        assert_eq!(descriptor.location, MemoryLocation::Host);
        assert_eq!(descriptor.dtype, Some(DType::F32));
        assert!(descriptor.shape.is_some());
        assert_eq!(descriptor.alignment, Some(64));
        assert!(descriptor.zero_init);
    }

    #[test]
    fn test_buffer_usage_flags() {
        let read_write = BufferUsage::READ | BufferUsage::WRITE;
        for flag in [BufferUsage::READ, BufferUsage::WRITE] {
            assert!(read_write.contains(flag));
        }
        assert!(!read_write.contains(BufferUsage::STORAGE));

        let storage_rw = BufferUsage::STORAGE_READ_WRITE;
        for flag in [BufferUsage::STORAGE, BufferUsage::READ, BufferUsage::WRITE] {
            assert!(storage_rw.contains(flag));
        }
    }

    #[test]
    fn test_buffer_handle_validation() {
        let valid = BufferHandle::Cpu {
            ptr: 0x1000 as *mut u8,
            size: 1024,
        };
        assert!(valid.is_valid());
        assert_eq!(valid.size(), 1024);

        let null_handle = BufferHandle::Cpu {
            ptr: std::ptr::null_mut(),
            size: 1024,
        };
        assert!(!null_handle.is_valid());
    }

    #[test]
    fn test_buffer_creation() {
        let device = create_test_device();
        let handle = BufferHandle::Cpu {
            ptr: 0x2000 as *mut u8,
            size: 512,
        };

        let buffer = Buffer::new(
            1,
            device.clone(),
            512,
            BufferUsage::READ_WRITE,
            BufferDescriptor::new(512, BufferUsage::READ_WRITE),
            handle,
        );

        assert_eq!(buffer.id(), 1);
        assert_eq!(buffer.size(), 512);
        assert_eq!(buffer.usage(), BufferUsage::READ_WRITE);
        assert_eq!(buffer.device().id(), device.id());
        assert!(buffer.supports_usage(BufferUsage::READ));
        assert!(buffer.supports_usage(BufferUsage::WRITE));
        assert!(!buffer.supports_usage(BufferUsage::STORAGE));
    }

    #[test]
    fn test_buffer_view_creation() {
        let view = BufferView::new(make_cpu_buffer(1, 0x3000, 1024), 256, 512).unwrap();
        assert_eq!(view.offset(), 256);
        assert_eq!(view.size(), 512);
        assert_eq!(view.end_offset(), 768);

        // 800 + 512 runs past the end of a buffer of size 1024.
        let out_of_bounds = BufferView::new(make_cpu_buffer(2, 0x3001, 1024), 800, 512);
        assert!(out_of_bounds.is_err());
    }

    #[test]
    fn test_buffer_view_typed() {
        let view = BufferView::new(make_cpu_buffer(1, 0x4000, 1024), 0, 1024)
            .unwrap()
            .typed(DType::F32)
            .shaped(Shape::new(vec![256]));

        assert_eq!(view.dtype, Some(DType::F32));
        assert!(view.shape.is_some());
    }

    #[test]
    fn test_memory_location_variants() {
        assert_eq!(MemoryLocation::default(), MemoryLocation::Device);

        [
            MemoryLocation::Device,
            MemoryLocation::Host,
            MemoryLocation::Unified,
            MemoryLocation::HostCached,
            MemoryLocation::DeviceHost,
        ]
        .iter()
        .copied()
        .for_each(|location| {
            let descriptor =
                BufferDescriptor::new(1024, BufferUsage::READ_WRITE).with_location(location);
            assert_eq!(descriptor.location, location);
        });
    }
}