Skip to main content

sapient_core/
buffer.rs

1//! `Buffer` trait and `CpuBuffer` — aligned, heap-allocated byte storage.
2
3use std::alloc::{self, Layout};
4use std::ptr::NonNull;
5use std::sync::Arc;
6
7use crate::dtype::DType;
8use crate::error::{Result, SapientError};
9
10// ── Buffer trait ─────────────────────────────────────────────────────────────
11
/// Byte-level storage that sits underneath a tensor.
///
/// A backend may keep the bytes on the CPU heap, in GPU device memory, or in
/// shared (unified) memory. The trait is kept object-safe so that backends
/// can hand out `Arc<dyn Buffer>`.
pub trait Buffer: std::fmt::Debug + Send + Sync {
    /// Total capacity of the buffer, in bytes.
    fn len(&self) -> usize;

    /// Reports whether the buffer has zero capacity.
    ///
    /// The default implementation is defined purely in terms of [`Buffer::len`].
    fn is_empty(&self) -> bool {
        let capacity = self.len();
        capacity == 0
    }

    /// The alignment that was used when this buffer was allocated.
    fn alignment(&self) -> usize;

    /// Short textual tag for where the buffer lives (e.g. "cpu", "metal").
    fn device(&self) -> &str;

    /// Read-only view of the raw bytes covering all elements.
    fn as_bytes(&self) -> &[u8];

    /// Mutable view of the raw bytes (may be unavailable for GPU buffers).
    fn as_bytes_mut(&mut self) -> &mut [u8];
}
38
39// ── BufferHandle ─────────────────────────────────────────────────────────────
40
41/// A reference-counted handle to a `Buffer`.
42#[derive(Debug, Clone)]
43pub struct BufferHandle(pub Arc<dyn Buffer>);
44
45impl BufferHandle {
46    pub fn new(buf: impl Buffer + 'static) -> Self {
47        Self(Arc::new(buf))
48    }
49
50    pub fn as_bytes(&self) -> &[u8] {
51        self.0.as_bytes()
52    }
53
54    pub fn len(&self) -> usize {
55        self.0.len()
56    }
57
58    pub fn is_empty(&self) -> bool {
59        self.0.is_empty()
60    }
61}
62
63// ── CpuBuffer ────────────────────────────────────────────────────────────────
64
/// CPU heap storage with an explicitly chosen alignment.
///
/// The memory is obtained straight from Rust's global allocator so that the
/// requested alignment is honoured; a `Vec<u8>` only guarantees the alignment
/// of its element type (1 byte).
pub struct CpuBuffer {
    // Start of the owned allocation.
    ptr: NonNull<u8>,
    // Logical size of the buffer in bytes.
    len: usize,
    // Alignment the buffer was requested with.
    align: usize,
    // Layout the allocation was made with.
    layout: Layout,
}

// SAFETY: this struct is the sole owner of the raw pointer, and the memory it
// refers to holds nothing but plain bytes, so sharing or moving the struct
// across threads is sound.
unsafe impl Sync for CpuBuffer {}
unsafe impl Send for CpuBuffer {}
80
81impl std::fmt::Debug for CpuBuffer {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("CpuBuffer")
84            .field("len", &self.len)
85            .field("align", &self.align)
86            .finish()
87    }
88}
89
90impl CpuBuffer {
91    /// Allocate a zero-initialised buffer for `numel` elements of `dtype`.
92    pub fn zeros(numel: usize, dtype: DType) -> Result<Self> {
93        let bytes = dtype.byte_count(numel);
94        let align = dtype.alignment().max(64); // cache-line friendly
95        Self::with_capacity(bytes, align)
96    }
97
98    /// Allocate `bytes` bytes with the given alignment.
99    pub fn with_capacity(bytes: usize, align: usize) -> Result<Self> {
100        if bytes == 0 {
101            // Zero-size allocation: use a dangling-but-aligned pointer.
102            let layout = Layout::from_size_align(1, align)
103                .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
104            let ptr = unsafe { alloc::alloc_zeroed(layout) };
105            let ptr = NonNull::new(ptr).ok_or(SapientError::AllocationFailed { bytes, align })?;
106            return Ok(Self {
107                ptr,
108                len: 0,
109                align,
110                layout,
111            });
112        }
113
114        let layout = Layout::from_size_align(bytes, align)
115            .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
116
117        // SAFETY: layout is well-formed.
118        let raw = unsafe { alloc::alloc_zeroed(layout) };
119        let ptr = NonNull::new(raw).ok_or(SapientError::AllocationFailed { bytes, align })?;
120
121        Ok(Self {
122            ptr,
123            len: bytes,
124            align,
125            layout,
126        })
127    }
128
129    /// Wrap existing `f32` data.
130    pub fn from_f32_slice(data: &[f32]) -> Result<Self> {
131        let bytes = data.len() * 4;
132        let buf = Self::with_capacity(bytes, 64)?;
133        // SAFETY: sizes are consistent and both regions are valid.
134        unsafe {
135            std::ptr::copy_nonoverlapping(data.as_ptr() as *const u8, buf.ptr.as_ptr(), bytes);
136        }
137        Ok(buf)
138    }
139
140    /// Wrap existing raw bytes (e.g., native BF16 or F16 from safetensors).
141    pub fn from_bytes_slice(data: &[u8]) -> Result<Self> {
142        let bytes = data.len();
143        if bytes == 0 {
144            return Self::with_capacity(0, 16);
145        }
146        let buf = Self::with_capacity(bytes, 16)?;
147        // SAFETY: sizes are consistent and both regions are valid.
148        unsafe {
149            std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr.as_ptr(), bytes);
150        }
151        Ok(buf)
152    }
153
154    /// View as `f32` slice (panics if not properly aligned/sized).
155    pub fn as_f32_slice(&self) -> &[f32] {
156        assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
157        // SAFETY: alignment guaranteed, size checked above.
158        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const f32, self.len / 4) }
159    }
160
161    /// Mutable view as `f32` slice.
162    pub fn as_f32_slice_mut(&mut self) -> &mut [f32] {
163        assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
164        // SAFETY: alignment guaranteed, exclusive via `&mut`.
165        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f32, self.len / 4) }
166    }
167
168    /// Pointer to raw memory (useful for BLAS calls).
169    pub fn as_ptr(&self) -> *const u8 {
170        self.ptr.as_ptr()
171    }
172
173    pub fn as_mut_ptr(&mut self) -> *mut u8 {
174        self.ptr.as_ptr()
175    }
176}
177
178impl Buffer for CpuBuffer {
179    fn as_bytes(&self) -> &[u8] {
180        // SAFETY: ptr is valid for `len` bytes, exclusively owned.
181        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
182    }
183
184    fn as_bytes_mut(&mut self) -> &mut [u8] {
185        // SAFETY: `&mut self` guarantees exclusive access.
186        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
187    }
188
189    fn len(&self) -> usize {
190        self.len
191    }
192
193    fn alignment(&self) -> usize {
194        self.align
195    }
196
197    fn device(&self) -> &str {
198        "cpu"
199    }
200}
201
202impl Drop for CpuBuffer {
203    fn drop(&mut self) {
204        // SAFETY: layout matches the original allocation.
205        unsafe { alloc::dealloc(self.ptr.as_ptr(), self.layout) }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly zeroed buffer has the expected byte length and holds no
    /// non-zero byte.
    #[test]
    fn zeros_and_read() {
        let zeroed = CpuBuffer::zeros(4, DType::F32).unwrap();
        assert_eq!(zeroed.len(), 16);
        assert!(!zeroed.as_bytes().iter().any(|&byte| byte != 0));
    }

    /// Values copied in from an `f32` slice read back unchanged.
    #[test]
    fn from_f32_roundtrip() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let buf = CpuBuffer::from_f32_slice(&values).unwrap();
        assert_eq!(buf.as_f32_slice(), &values[..]);
    }

    /// The start address honours the requested 64-byte alignment.
    #[test]
    fn alignment_guarantee() {
        let buf = CpuBuffer::with_capacity(32, 64).unwrap();
        let address = buf.as_ptr() as usize;
        assert_eq!(address % 64, 0);
    }
}