1use std::alloc::{self, Layout};
4use std::ptr::NonNull;
5use std::sync::Arc;
6
7use crate::dtype::DType;
8use crate::error::{Result, SapientError};
9
/// A contiguous byte buffer that can be shared across threads and that reports its length,
/// alignment and the name of the device it lives on.
pub trait Buffer: Send + Sync + std::fmt::Debug {
    /// Borrows the buffer contents as an immutable byte slice.
    fn as_bytes(&self) -> &[u8];

    /// Borrows the buffer contents as a mutable byte slice.
    fn as_bytes_mut(&mut self) -> &mut [u8];

    /// Length of the buffer in bytes (expected to match `as_bytes().len()`).
    fn len(&self) -> usize;

    /// Returns `true` when [`Buffer::len`] reports zero bytes.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Alignment of the buffer's storage, in bytes.
    fn alignment(&self) -> usize;

    /// Name of the device holding the storage (for example `"cpu"`).
    fn device(&self) -> &str;
}
38
39#[derive(Debug, Clone)]
43pub struct BufferHandle(pub Arc<dyn Buffer>);
44
45impl BufferHandle {
46 pub fn new(buf: impl Buffer + 'static) -> Self {
47 Self(Arc::new(buf))
48 }
49
50 pub fn as_bytes(&self) -> &[u8] {
51 self.0.as_bytes()
52 }
53
54 pub fn len(&self) -> usize {
55 self.0.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.0.is_empty()
60 }
61}
62
/// Heap buffer obtained from the global allocator with a caller-chosen alignment.
///
/// The buffer exclusively owns its allocation and frees it on drop using `layout`.
pub struct CpuBuffer {
    /// Start of the allocation; never null.
    ptr: NonNull<u8>,
    /// Layout handed to the allocator; also used to free the memory.
    layout: Layout,
    /// Number of bytes exposed to users. This can be smaller than `layout.size()`, because an
    /// empty buffer still owns a 1-byte allocation.
    len: usize,
    /// Alignment requested at construction time, in bytes.
    align: usize,
}

// SAFETY: `CpuBuffer` is the sole owner of the memory behind `ptr` (it is not `Clone` and never
// hands out ownership of the pointer), so sending it to another thread is equivalent to
// sending a `Box<[u8]>`.
unsafe impl std::marker::Send for CpuBuffer {}
// SAFETY: in the methods defined alongside this type, shared access (`&CpuBuffer`) only yields
// read-only views (`&[u8]`, `&[f32]`, `*const u8`), and every mutating accessor takes
// `&mut self`, so concurrent `&CpuBuffer` use cannot race.
unsafe impl std::marker::Sync for CpuBuffer {}
80
81impl std::fmt::Debug for CpuBuffer {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("CpuBuffer")
84 .field("len", &self.len)
85 .field("align", &self.align)
86 .finish()
87 }
88}
89
90impl CpuBuffer {
91 pub fn zeros(numel: usize, dtype: DType) -> Result<Self> {
93 let bytes = dtype.byte_count(numel);
94 let align = dtype.alignment().max(64); Self::with_capacity(bytes, align)
96 }
97
98 pub fn with_capacity(bytes: usize, align: usize) -> Result<Self> {
100 if bytes == 0 {
101 let layout = Layout::from_size_align(1, align)
103 .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
104 let ptr = unsafe { alloc::alloc_zeroed(layout) };
105 let ptr = NonNull::new(ptr).ok_or(SapientError::AllocationFailed { bytes, align })?;
106 return Ok(Self {
107 ptr,
108 len: 0,
109 align,
110 layout,
111 });
112 }
113
114 let layout = Layout::from_size_align(bytes, align)
115 .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
116
117 let raw = unsafe { alloc::alloc_zeroed(layout) };
119 let ptr = NonNull::new(raw).ok_or(SapientError::AllocationFailed { bytes, align })?;
120
121 Ok(Self {
122 ptr,
123 len: bytes,
124 align,
125 layout,
126 })
127 }
128
129 pub fn from_f32_slice(data: &[f32]) -> Result<Self> {
131 let bytes = data.len() * 4;
132 let buf = Self::with_capacity(bytes, 64)?;
133 unsafe {
135 std::ptr::copy_nonoverlapping(data.as_ptr() as *const u8, buf.ptr.as_ptr(), bytes);
136 }
137 Ok(buf)
138 }
139
140 pub fn from_f32_vec(data: Vec<f32>) -> Result<Self> {
144 if data.is_empty() {
145 return Self::with_capacity(0, 4);
146 }
147 let len = data.len() * 4;
148 let layout = Layout::array::<f32>(data.len())
149 .map_err(|_| SapientError::AllocationFailed { bytes: len, align: 4 })?;
150 let ptr = data.as_ptr() as *mut u8;
151 std::mem::forget(data);
155 Ok(Self {
156 ptr: NonNull::new(ptr)
157 .ok_or(SapientError::AllocationFailed { bytes: len, align: 4 })?,
158 len,
159 align: std::mem::align_of::<f32>(),
160 layout,
161 })
162 }
163
164 pub fn from_bytes_slice(data: &[u8]) -> Result<Self> {
166 let bytes = data.len();
167 if bytes == 0 {
168 return Self::with_capacity(0, 16);
169 }
170 let buf = Self::with_capacity(bytes, 16)?;
171 unsafe {
173 std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr.as_ptr(), bytes);
174 }
175 Ok(buf)
176 }
177
178 pub fn as_f32_slice(&self) -> &[f32] {
180 assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
181 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const f32, self.len / 4) }
183 }
184
185 pub fn as_f32_slice_mut(&mut self) -> &mut [f32] {
187 assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
188 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f32, self.len / 4) }
190 }
191
192 pub fn as_ptr(&self) -> *const u8 {
194 self.ptr.as_ptr()
195 }
196
197 pub fn as_mut_ptr(&mut self) -> *mut u8 {
198 self.ptr.as_ptr()
199 }
200}
201
202impl Buffer for CpuBuffer {
203 fn as_bytes(&self) -> &[u8] {
204 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
206 }
207
208 fn as_bytes_mut(&mut self) -> &mut [u8] {
209 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
211 }
212
213 fn len(&self) -> usize {
214 self.len
215 }
216
217 fn alignment(&self) -> usize {
218 self.align
219 }
220
221 fn device(&self) -> &str {
222 "cpu"
223 }
224}
225
226impl Drop for CpuBuffer {
227 fn drop(&mut self) {
228 unsafe { alloc::dealloc(self.ptr.as_ptr(), self.layout) }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_and_read() {
        let buf = CpuBuffer::zeros(4, DType::F32).unwrap();
        assert_eq!(buf.len(), 16);
        assert!(buf.as_bytes().iter().all(|&b| b == 0));
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let buf = CpuBuffer::from_f32_slice(&data).unwrap();
        assert_eq!(buf.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn alignment_guarantee() {
        let buf = CpuBuffer::with_capacity(32, 64).unwrap();
        assert_eq!(buf.as_ptr() as usize % 64, 0);
    }

    /// A `Vec` with spare capacity must round-trip and be freed correctly.
    /// Run under Miri to catch a `dealloc` whose layout does not match the allocation.
    #[test]
    fn from_f32_vec_with_spare_capacity() {
        let mut data: Vec<f32> = Vec::with_capacity(32);
        data.extend_from_slice(&[1.0, 2.0, 3.0]);
        let buf = CpuBuffer::from_f32_vec(data).unwrap();
        assert_eq!(buf.len(), 12);
        assert_eq!(buf.as_f32_slice(), &[1.0f32, 2.0, 3.0][..]);
    }

    #[test]
    fn empty_buffers() {
        let a = CpuBuffer::with_capacity(0, 64).unwrap();
        assert!(a.is_empty());
        assert!(a.as_bytes().is_empty());

        let b = CpuBuffer::from_f32_vec(Vec::new()).unwrap();
        assert!(b.is_empty());
        assert!(b.as_f32_slice().is_empty());

        let c = CpuBuffer::from_bytes_slice(&[]).unwrap();
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn bytes_roundtrip_and_f32_mutation() {
        let mut buf = CpuBuffer::from_bytes_slice(&[0u8; 8]).unwrap();
        assert_eq!(buf.as_bytes(), &[0u8; 8][..]);
        assert_eq!(buf.alignment(), 16);
        assert_eq!(buf.device(), "cpu");
        buf.as_f32_slice_mut()[1] = 2.5;
        assert_eq!(buf.as_f32_slice(), &[0.0f32, 2.5][..]);
    }

    #[test]
    #[should_panic(expected = "buffer length not a multiple of 4")]
    fn f32_view_rejects_ragged_length() {
        let buf = CpuBuffer::with_capacity(3, 64).unwrap();
        let _ = buf.as_f32_slice();
    }

    #[test]
    fn handle_shares_one_buffer() {
        let handle = BufferHandle::new(CpuBuffer::from_f32_slice(&[1.0, 2.0]).unwrap());
        let copy = handle.clone();
        assert_eq!(copy.len(), 8);
        assert!(!copy.is_empty());
        assert_eq!(handle.as_bytes(), copy.as_bytes());
        assert!(std::sync::Arc::ptr_eq(&handle.0, &copy.0));
    }
}