// torsh_backend/buffer.rs
1//! Buffer management and memory operations
2
3use crate::Device;
4use torsh_core::{
5    dtype::DType,
6    error::{Result, TorshError},
7    shape::Shape,
8};
9
10#[cfg(not(feature = "std"))]
11use alloc::{boxed::Box, string::String, vec::Vec};
12
13#[cfg(not(feature = "std"))]
14use core::sync::atomic::{AtomicUsize, Ordering};
15#[cfg(feature = "std")]
16use std::sync::atomic::{AtomicUsize, Ordering};
17
/// Global monotonically increasing source of buffer IDs (first ID handed out is 1).
static BUFFER_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Hand out the next process-unique buffer ID.
///
/// `Ordering::Relaxed` suffices here: callers only need uniqueness, not any
/// happens-before relationship with other memory operations.
pub fn generate_buffer_id() -> usize {
    BUFFER_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
25
/// Handle to a region of device memory.
///
/// A `Buffer` is a lightweight record; cloning it duplicates the handle,
/// not the underlying allocation (see the manual `Clone` for `BufferHandle`).
#[derive(Debug, Clone)]
pub struct Buffer {
    /// Unique buffer ID
    pub id: usize,

    /// Device this buffer belongs to
    pub device: Device,

    /// Buffer size in bytes
    pub size: usize,

    /// Buffer usage flags granted at creation
    pub usage: BufferUsage,

    /// Buffer descriptor used for creation (kept for introspection)
    pub descriptor: BufferDescriptor,

    /// Backend-specific handle (opaque to callers)
    pub handle: BufferHandle,
}
47
48impl Buffer {
49    /// Create a new buffer
50    pub fn new(
51        id: usize,
52        device: Device,
53        size: usize,
54        usage: BufferUsage,
55        descriptor: BufferDescriptor,
56        handle: BufferHandle,
57    ) -> Self {
58        Self {
59            id,
60            device,
61            size,
62            usage,
63            descriptor,
64            handle,
65        }
66    }
67
68    /// Get buffer ID
69    pub fn id(&self) -> usize {
70        self.id
71    }
72
73    /// Get the device this buffer belongs to
74    pub fn device(&self) -> &Device {
75        &self.device
76    }
77
78    /// Get buffer size in bytes
79    pub fn size(&self) -> usize {
80        self.size
81    }
82
83    /// Get buffer usage flags
84    pub fn usage(&self) -> BufferUsage {
85        self.usage
86    }
87
88    /// Get the backend-specific handle
89    pub fn handle(&self) -> &BufferHandle {
90        &self.handle
91    }
92
93    /// Check if buffer can be used for the given usage
94    pub fn supports_usage(&self, usage: BufferUsage) -> bool {
95        self.usage.contains(usage)
96    }
97}
98
/// Creation-time description of a buffer.
///
/// Built with `BufferDescriptor::new` plus the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct BufferDescriptor {
    /// Buffer size in bytes
    pub size: usize,

    /// Buffer usage flags
    pub usage: BufferUsage,

    /// Memory location hint (defaults to `MemoryLocation::Device`)
    pub location: MemoryLocation,

    /// Data type stored in buffer (for type safety); `None` for raw bytes
    pub dtype: Option<DType>,

    /// Shape of data in buffer (for tensor operations)
    pub shape: Option<Shape>,

    /// Initial data to copy to buffer after allocation
    pub initial_data: Option<Vec<u8>>,

    /// Memory alignment requirement in bytes; `None` leaves the choice
    /// to the backend
    pub alignment: Option<usize>,

    /// Whether buffer should be zero-initialized
    pub zero_init: bool,
}
126
127impl BufferDescriptor {
128    /// Create a new buffer descriptor
129    pub fn new(size: usize, usage: BufferUsage) -> Self {
130        Self {
131            size,
132            usage,
133            location: MemoryLocation::Device,
134            dtype: None,
135            shape: None,
136            initial_data: None,
137            alignment: None,
138            zero_init: false,
139        }
140    }
141
142    /// Set memory location
143    pub fn with_location(mut self, location: MemoryLocation) -> Self {
144        self.location = location;
145        self
146    }
147
148    /// Set data type
149    pub fn with_dtype(mut self, dtype: DType) -> Self {
150        self.dtype = Some(dtype);
151        self
152    }
153
154    /// Set shape
155    pub fn with_shape(mut self, shape: Shape) -> Self {
156        self.shape = Some(shape);
157        self
158    }
159
160    /// Set initial data
161    pub fn with_initial_data(mut self, data: Vec<u8>) -> Self {
162        self.initial_data = Some(data);
163        self
164    }
165
166    /// Set alignment requirement
167    pub fn with_alignment(mut self, alignment: usize) -> Self {
168        self.alignment = Some(alignment);
169        self
170    }
171
172    /// Enable zero initialization
173    pub fn with_zero_init(mut self) -> Self {
174        self.zero_init = true;
175        self
176    }
177}
178
/// Bit-set of buffer usage capabilities.
///
/// Individual capabilities are exposed as associated constants and can be
/// combined with [`BufferUsage::union`] (or the `|` operator implemented
/// elsewhere in this module).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferUsage {
    bits: u32,
}

impl BufferUsage {
    /// Empty set: no capability granted.
    pub const NONE: Self = Self::new(0);

    /// Buffer can be read from.
    pub const READ: Self = Self::new(1 << 0);

    /// Buffer can be written to.
    pub const WRITE: Self = Self::new(1 << 1);

    /// Buffer can back a storage binding (compute shader).
    pub const STORAGE: Self = Self::new(1 << 2);

    /// Buffer can back a uniform binding.
    pub const UNIFORM: Self = Self::new(1 << 3);

    /// Buffer can hold vertex data.
    pub const VERTEX: Self = Self::new(1 << 4);

    /// Buffer can hold index data.
    pub const INDEX: Self = Self::new(1 << 5);

    /// Buffer can be the source of a copy.
    pub const COPY_SRC: Self = Self::new(1 << 6);

    /// Buffer can be the destination of a copy.
    pub const COPY_DST: Self = Self::new(1 << 7);

    /// Buffer can be mapped for host reads.
    pub const MAP_READ: Self = Self::new(1 << 8);

    /// Buffer can be mapped for host writes.
    pub const MAP_WRITE: Self = Self::new(1 << 9);

    /// Convenience combination: read + write.
    pub const READ_WRITE: Self = Self::READ.union(Self::WRITE);

    /// Convenience combination: storage + read + write.
    pub const STORAGE_READ_WRITE: Self = Self::STORAGE.union(Self::READ_WRITE);

    /// Build a usage set from raw bits.
    pub const fn new(bits: u32) -> Self {
        Self { bits }
    }

    /// True when every flag set in `other` is also set in `self`.
    pub const fn contains(self, other: Self) -> bool {
        (self.bits & other.bits) == other.bits
    }

    /// Set union of two flag sets.
    pub const fn union(self, other: Self) -> Self {
        Self::new(self.bits | other.bits)
    }

    /// Set difference: flags in `self` that are not in `other`.
    pub const fn difference(self, other: Self) -> Self {
        Self::new(self.bits & !other.bits)
    }

    /// Raw bit representation.
    pub const fn bits(self) -> u32 {
        self.bits
    }
}
256
257impl std::ops::BitOr for BufferUsage {
258    type Output = Self;
259
260    fn bitor(self, rhs: Self) -> Self::Output {
261        self.union(rhs)
262    }
263}
264
265impl std::ops::BitOrAssign for BufferUsage {
266    fn bitor_assign(&mut self, rhs: Self) {
267        *self = *self | rhs;
268    }
269}
270
/// Hint describing where buffer memory should be allocated.
///
/// This is a hint; how each variant maps onto real allocations is up to
/// the individual backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLocation {
    /// Device memory (GPU VRAM, etc.) — the default
    #[default]
    Device,

    /// Host memory (system RAM)
    Host,

    /// Unified memory (accessible from both host and device)
    Unified,

    /// Host memory that is cached by device
    HostCached,

    /// Device memory that is visible to host
    DeviceHost,
}
290
/// Backend-specific buffer handle.
///
/// Variants other than `Cpu` and `Generic` are compiled in only when the
/// corresponding backend feature (and, for Metal, target) is enabled.
#[derive(Debug)]
pub enum BufferHandle {
    /// CPU buffer (raw pointer)
    Cpu { ptr: *mut u8, size: usize },

    /// CUDA buffer (raw device address)
    #[cfg(feature = "cuda")]
    Cuda { device_ptr: u64, size: usize },

    /// Metal buffer (backend-assigned ID)
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { buffer_id: u64, size: usize },

    /// WebGPU buffer (backend-assigned pointer value)
    #[cfg(feature = "webgpu")]
    WebGpu { buffer_ptr: u64, size: usize },

    /// Generic handle for custom backends
    Generic {
        // `core::any::Any` (not `std::any::Any`) so this type still compiles
        // in no_std builds; under std the two paths are the same trait.
        handle: Box<dyn core::any::Any + Send + Sync>,
        size: usize,
    },
}
315
impl Clone for BufferHandle {
    /// Duplicate the handle value; the underlying device/host memory is NOT
    /// copied — both clones refer to the same allocation.
    ///
    /// # Panics
    /// Panics on `Generic` handles: their boxed `dyn Any` payload cannot be
    /// cloned.
    fn clone(&self) -> Self {
        match self {
            // Plain-data variants: copy the address/ID and size.
            BufferHandle::Cpu { ptr, size } => BufferHandle::Cpu {
                ptr: *ptr,
                size: *size,
            },
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { device_ptr, size } => BufferHandle::Cuda {
                device_ptr: *device_ptr,
                size: *size,
            },
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { buffer_id, size } => BufferHandle::Metal {
                buffer_id: *buffer_id,
                size: *size,
            },
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { buffer_ptr, size } => BufferHandle::WebGpu {
                buffer_ptr: *buffer_ptr,
                size: *size,
            },
            BufferHandle::Generic { .. } => {
                // For Generic handles, we can't actually clone the Box<dyn Any>
                // This is a limitation - in practice, backends should avoid using Generic handles
                // for buffers that need to be cloned
                panic!("Cannot clone Generic buffer handles")
            }
        }
    }
}
347
348impl BufferHandle {
349    /// Get the size of the buffer
350    pub fn size(&self) -> usize {
351        match self {
352            BufferHandle::Cpu { size, .. } => *size,
353            #[cfg(feature = "cuda")]
354            BufferHandle::Cuda { size, .. } => *size,
355            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
356            BufferHandle::Metal { size, .. } => *size,
357            #[cfg(feature = "webgpu")]
358            BufferHandle::WebGpu { size, .. } => *size,
359            BufferHandle::Generic { size, .. } => *size,
360        }
361    }
362
363    /// Get a unique identifier for this buffer handle
364    pub fn id(&self) -> usize {
365        match self {
366            BufferHandle::Cpu { ptr, .. } => *ptr as usize,
367            #[cfg(feature = "cuda")]
368            BufferHandle::Cuda { device_ptr, .. } => *device_ptr as usize,
369            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
370            BufferHandle::Metal { buffer_id, .. } => *buffer_id as usize,
371            #[cfg(feature = "webgpu")]
372            BufferHandle::WebGpu { buffer_ptr, .. } => *buffer_ptr as usize,
373            BufferHandle::Generic { .. } => 0, // Generic handles don't have meaningful IDs
374        }
375    }
376
377    /// Check if handle is valid
378    pub fn is_valid(&self) -> bool {
379        match self {
380            BufferHandle::Cpu { ptr, size } => !ptr.is_null() && *size > 0,
381            #[cfg(feature = "cuda")]
382            BufferHandle::Cuda { device_ptr, size } => *device_ptr != 0 && *size > 0,
383            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
384            BufferHandle::Metal { buffer_id, size } => *buffer_id != 0 && *size > 0,
385            #[cfg(feature = "webgpu")]
386            BufferHandle::WebGpu { buffer_ptr, size } => *buffer_ptr != 0 && *size > 0,
387            BufferHandle::Generic { size, .. } => *size > 0,
388        }
389    }
390}
391
// The raw `ptr` in `BufferHandle::Cpu` suppresses the automatic Send/Sync
// impls, so they are asserted manually here.
//
// SAFETY: a `BufferHandle` is only an address/ID plus a length; these impls
// assume the owning backend serializes actual access to the memory behind it.
// NOTE(review): soundness rests entirely on that backend discipline — confirm
// no backend dereferences a shared `Cpu` pointer concurrently.
unsafe impl Send for BufferHandle {}
unsafe impl Sync for BufferHandle {}
396
397impl PartialEq for BufferHandle {
398    fn eq(&self, other: &Self) -> bool {
399        match (self, other) {
400            (
401                BufferHandle::Cpu {
402                    ptr: ptr1,
403                    size: size1,
404                },
405                BufferHandle::Cpu {
406                    ptr: ptr2,
407                    size: size2,
408                },
409            ) => ptr1 == ptr2 && size1 == size2,
410            #[cfg(feature = "cuda")]
411            (
412                BufferHandle::Cuda {
413                    device_ptr: ptr1,
414                    size: size1,
415                },
416                BufferHandle::Cuda {
417                    device_ptr: ptr2,
418                    size: size2,
419                },
420            ) => ptr1 == ptr2 && size1 == size2,
421            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
422            (
423                BufferHandle::Metal {
424                    buffer_id: id1,
425                    size: size1,
426                },
427                BufferHandle::Metal {
428                    buffer_id: id2,
429                    size: size2,
430                },
431            ) => id1 == id2 && size1 == size2,
432            #[cfg(feature = "webgpu")]
433            (
434                BufferHandle::WebGpu {
435                    buffer_ptr: ptr1,
436                    size: size1,
437                },
438                BufferHandle::WebGpu {
439                    buffer_ptr: ptr2,
440                    size: size2,
441                },
442            ) => ptr1 == ptr2 && size1 == size2,
443            (
444                BufferHandle::Generic { size: size1, .. },
445                BufferHandle::Generic { size: size2, .. },
446            ) => {
447                // For Generic handles, we can only compare sizes
448                size1 == size2
449            }
450            _ => false,
451        }
452    }
453}
454
455impl Eq for BufferHandle {}
456
457impl std::hash::Hash for BufferHandle {
458    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
459        match self {
460            BufferHandle::Cpu { ptr, size } => {
461                0u8.hash(state); // discriminant
462                (*ptr as usize).hash(state);
463                size.hash(state);
464            }
465            #[cfg(feature = "cuda")]
466            BufferHandle::Cuda { device_ptr, size } => {
467                1u8.hash(state); // discriminant
468                device_ptr.hash(state);
469                size.hash(state);
470            }
471            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
472            BufferHandle::Metal { buffer_id, size } => {
473                2u8.hash(state); // discriminant
474                buffer_id.hash(state);
475                size.hash(state);
476            }
477            #[cfg(feature = "webgpu")]
478            BufferHandle::WebGpu { buffer_ptr, size } => {
479                3u8.hash(state); // discriminant
480                buffer_ptr.hash(state);
481                size.hash(state);
482            }
483            BufferHandle::Generic { size, .. } => {
484                4u8.hash(state); // discriminant
485                size.hash(state);
486            }
487        }
488    }
489}
490
/// View over a sub-range of a parent `Buffer`.
///
/// When constructed via `BufferView::new`, the range is validated to satisfy
/// `offset + size <= buffer.size`.
#[derive(Debug)]
pub struct BufferView {
    /// Parent buffer (owned by the view)
    pub buffer: Buffer,

    /// Offset in bytes from start of buffer
    pub offset: usize,

    /// Size of the view in bytes
    pub size: usize,

    /// Data type for typed views (`None` for raw byte views)
    pub dtype: Option<DType>,

    /// Shape for tensor views
    pub shape: Option<Shape>,
}
509
510impl BufferView {
511    /// Create a new buffer view
512    pub fn new(buffer: Buffer, offset: usize, size: usize) -> Result<Self> {
513        if offset + size > buffer.size {
514            return Err(TorshError::InvalidArgument(
515                "Buffer view exceeds buffer bounds".to_string(),
516            ));
517        }
518
519        Ok(Self {
520            buffer,
521            offset,
522            size,
523            dtype: None,
524            shape: None,
525        })
526    }
527
528    /// Create a typed buffer view
529    pub fn typed(mut self, dtype: DType) -> Self {
530        self.dtype = Some(dtype);
531        self
532    }
533
534    /// Create a tensor buffer view
535    pub fn shaped(mut self, shape: Shape) -> Self {
536        self.shape = Some(shape);
537        self
538    }
539
540    /// Get the underlying buffer
541    pub fn buffer(&self) -> &Buffer {
542        &self.buffer
543    }
544
545    /// Get the offset
546    pub fn offset(&self) -> usize {
547        self.offset
548    }
549
550    /// Get the size
551    pub fn size(&self) -> usize {
552        self.size
553    }
554
555    /// Get the end offset
556    pub fn end_offset(&self) -> usize {
557        self.offset + self.size
558    }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType, shape::Shape};

    // Minimal CPU device used as a stand-in across the tests below.
    fn create_test_device() -> Device {
        let info = DeviceInfo::default();
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), info)
    }

    // A fresh descriptor defaults to device-local, untyped, uninitialized.
    #[test]
    fn test_buffer_descriptor_creation() {
        let desc = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);

        assert_eq!(desc.size, 1024);
        assert_eq!(desc.usage, BufferUsage::READ_WRITE);
        assert_eq!(desc.location, MemoryLocation::Device);
        assert_eq!(desc.dtype, None);
        assert_eq!(desc.shape, None);
        assert_eq!(desc.initial_data, None);
        assert_eq!(desc.alignment, None);
        assert!(!desc.zero_init);
    }

    // Each with_* builder call sets exactly its own field.
    #[test]
    fn test_buffer_descriptor_builder() {
        let desc = BufferDescriptor::new(2048, BufferUsage::STORAGE)
            .with_location(MemoryLocation::Host)
            .with_dtype(DType::F32)
            .with_shape(Shape::new(vec![64, 32]))
            .with_alignment(64)
            .with_zero_init();

        assert_eq!(desc.size, 2048);
        assert_eq!(desc.usage, BufferUsage::STORAGE);
        assert_eq!(desc.location, MemoryLocation::Host);
        assert_eq!(desc.dtype, Some(DType::F32));
        assert!(desc.shape.is_some());
        assert_eq!(desc.alignment, Some(64));
        assert!(desc.zero_init);
    }

    // Flag composition via `|` and containment via `contains`.
    #[test]
    fn test_buffer_usage_flags() {
        let usage = BufferUsage::READ | BufferUsage::WRITE;
        assert!(usage.contains(BufferUsage::READ));
        assert!(usage.contains(BufferUsage::WRITE));
        assert!(!usage.contains(BufferUsage::STORAGE));

        let combined = BufferUsage::STORAGE_READ_WRITE;
        assert!(combined.contains(BufferUsage::STORAGE));
        assert!(combined.contains(BufferUsage::READ));
        assert!(combined.contains(BufferUsage::WRITE));
    }

    // A CPU handle is valid only with a non-null pointer and non-zero size.
    #[test]
    fn test_buffer_handle_validation() {
        let handle_valid = BufferHandle::Cpu {
            // Arbitrary non-null address; it is never dereferenced.
            ptr: 0x1000 as *mut u8,
            size: 1024,
        };
        assert!(handle_valid.is_valid());
        assert_eq!(handle_valid.size(), 1024);

        let handle_invalid = BufferHandle::Cpu {
            ptr: std::ptr::null_mut(),
            size: 1024,
        };
        assert!(!handle_invalid.is_valid());
    }

    // Buffer accessors reflect the values passed to Buffer::new.
    #[test]
    fn test_buffer_creation() {
        let device = create_test_device();
        let desc = BufferDescriptor::new(512, BufferUsage::READ_WRITE);
        let handle = BufferHandle::Cpu {
            ptr: 0x2000 as *mut u8,
            size: 512,
        };

        let buffer = Buffer::new(
            1,
            device.clone(),
            512,
            BufferUsage::READ_WRITE,
            desc.clone(),
            handle,
        );

        assert_eq!(buffer.id(), 1);
        assert_eq!(buffer.size(), 512);
        assert_eq!(buffer.usage(), BufferUsage::READ_WRITE);
        assert_eq!(buffer.device().id(), device.id());
        assert!(buffer.supports_usage(BufferUsage::READ));
        assert!(buffer.supports_usage(BufferUsage::WRITE));
        assert!(!buffer.supports_usage(BufferUsage::STORAGE));
    }

    // In-bounds views succeed; views past the end of the buffer error.
    #[test]
    fn test_buffer_view_creation() {
        let device = create_test_device();
        let desc = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);
        let handle = BufferHandle::Cpu {
            ptr: 0x3000 as *mut u8,
            size: 1024,
        };

        let buffer = Buffer::new(1, device, 1024, BufferUsage::READ_WRITE, desc, handle);

        // Valid buffer view
        let view = BufferView::new(buffer, 256, 512).unwrap();
        assert_eq!(view.offset(), 256);
        assert_eq!(view.size(), 512);
        assert_eq!(view.end_offset(), 768);

        // Test with a new buffer for invalid case (800 + 512 > 1024)
        let device2 = create_test_device();
        let desc2 = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);
        let handle2 = BufferHandle::Cpu {
            ptr: 0x3001 as *mut u8,
            size: 1024,
        };
        let buffer2 = Buffer::new(2, device2, 1024, BufferUsage::READ_WRITE, desc2, handle2);
        let invalid_view = BufferView::new(buffer2, 800, 512);
        assert!(invalid_view.is_err());
    }

    // typed()/shaped() record dtype and shape on the view.
    #[test]
    fn test_buffer_view_typed() {
        let device = create_test_device();
        let desc = BufferDescriptor::new(1024, BufferUsage::READ_WRITE);
        let handle = BufferHandle::Cpu {
            ptr: 0x4000 as *mut u8,
            size: 1024,
        };

        let buffer = Buffer::new(1, device, 1024, BufferUsage::READ_WRITE, desc, handle);
        let view = BufferView::new(buffer, 0, 1024)
            .unwrap()
            .typed(DType::F32)
            .shaped(Shape::new(vec![256])); // 256 f32 values = 1024 bytes

        assert_eq!(view.dtype, Some(DType::F32));
        assert!(view.shape.is_some());
    }

    // Every MemoryLocation variant round-trips through the builder.
    #[test]
    fn test_memory_location_variants() {
        assert_eq!(MemoryLocation::default(), MemoryLocation::Device);

        let locations = [
            MemoryLocation::Device,
            MemoryLocation::Host,
            MemoryLocation::Unified,
            MemoryLocation::HostCached,
            MemoryLocation::DeviceHost,
        ];

        for location in locations {
            let desc = BufferDescriptor::new(1024, BufferUsage::READ_WRITE).with_location(location);
            assert_eq!(desc.location, location);
        }
    }
}