1use std::alloc::{self, Layout};
4use std::ptr::NonNull;
5use std::sync::Arc;
6
7use crate::dtype::DType;
8use crate::error::{Result, SapientError};
9
/// A contiguous block of raw storage that can be shared across threads.
///
/// Implementors expose their contents as bytes and report a length, an
/// alignment and the name of the device the storage belongs to.
pub trait Buffer: Send + Sync + std::fmt::Debug {
    /// Returns the length reported by the implementor.
    fn len(&self) -> usize;

    /// Returns `true` when [`Buffer::len`] reports zero.
    fn is_empty(&self) -> bool {
        let size = self.len();
        size == 0
    }

    /// Borrows the contents as an immutable byte slice.
    fn as_bytes(&self) -> &[u8];

    /// Borrows the contents as a mutable byte slice.
    fn as_bytes_mut(&mut self) -> &mut [u8];

    /// Returns the alignment associated with the storage.
    fn alignment(&self) -> usize;

    /// Returns the name of the device that holds the storage.
    fn device(&self) -> &str;
}
38
39#[derive(Debug, Clone)]
43pub struct BufferHandle(pub Arc<dyn Buffer>);
44
45impl BufferHandle {
46 pub fn new(buf: impl Buffer + 'static) -> Self {
47 Self(Arc::new(buf))
48 }
49
50 pub fn as_bytes(&self) -> &[u8] {
51 self.0.as_bytes()
52 }
53
54 pub fn len(&self) -> usize {
55 self.0.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.0.is_empty()
60 }
61}
62
/// A heap allocation described by a raw pointer, a length and the layout it
/// was allocated with.
pub struct CpuBuffer {
    /// Start of the allocation; never null.
    ptr: NonNull<u8>,
    /// Number of bytes exposed through the byte-slice accessors.
    len: usize,
    /// Alignment recorded when the buffer was built.
    align: usize,
    /// Layout of the allocation, kept so the same layout can be handed back
    /// to the allocator later.
    layout: Layout,
}

// SAFETY: `NonNull<u8>` is what stops the compiler from deriving these
// automatically.
// NOTE(review): soundness relies on `ptr` not being aliased by anything
// outside this value and on mutation only happening through `&mut self` —
// confirm against every constructor and accessor.
unsafe impl Send for CpuBuffer {}
unsafe impl Sync for CpuBuffer {}
80
81impl std::fmt::Debug for CpuBuffer {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("CpuBuffer")
84 .field("len", &self.len)
85 .field("align", &self.align)
86 .finish()
87 }
88}
89
90impl CpuBuffer {
91 pub fn zeros(numel: usize, dtype: DType) -> Result<Self> {
93 let bytes = dtype.byte_count(numel);
94 let align = dtype.alignment().max(64); Self::with_capacity(bytes, align)
96 }
97
98 pub fn with_capacity(bytes: usize, align: usize) -> Result<Self> {
100 if bytes == 0 {
101 let layout = Layout::from_size_align(1, align)
103 .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
104 let ptr = unsafe { alloc::alloc_zeroed(layout) };
105 let ptr = NonNull::new(ptr).ok_or(SapientError::AllocationFailed { bytes, align })?;
106 return Ok(Self {
107 ptr,
108 len: 0,
109 align,
110 layout,
111 });
112 }
113
114 let layout = Layout::from_size_align(bytes, align)
115 .map_err(|_| SapientError::AllocationFailed { bytes, align })?;
116
117 let raw = unsafe { alloc::alloc_zeroed(layout) };
119 let ptr = NonNull::new(raw).ok_or(SapientError::AllocationFailed { bytes, align })?;
120
121 Ok(Self {
122 ptr,
123 len: bytes,
124 align,
125 layout,
126 })
127 }
128
129 pub fn from_f32_slice(data: &[f32]) -> Result<Self> {
131 let bytes = data.len() * 4;
132 let buf = Self::with_capacity(bytes, 64)?;
133 unsafe {
135 std::ptr::copy_nonoverlapping(data.as_ptr() as *const u8, buf.ptr.as_ptr(), bytes);
136 }
137 Ok(buf)
138 }
139
140 pub fn from_bytes_slice(data: &[u8]) -> Result<Self> {
142 let bytes = data.len();
143 if bytes == 0 {
144 return Self::with_capacity(0, 16);
145 }
146 let buf = Self::with_capacity(bytes, 16)?;
147 unsafe {
149 std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr.as_ptr(), bytes);
150 }
151 Ok(buf)
152 }
153
154 pub fn as_f32_slice(&self) -> &[f32] {
156 assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
157 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const f32, self.len / 4) }
159 }
160
161 pub fn as_f32_slice_mut(&mut self) -> &mut [f32] {
163 assert_eq!(self.len % 4, 0, "buffer length not a multiple of 4");
164 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut f32, self.len / 4) }
166 }
167
168 pub fn as_ptr(&self) -> *const u8 {
170 self.ptr.as_ptr()
171 }
172
173 pub fn as_mut_ptr(&mut self) -> *mut u8 {
174 self.ptr.as_ptr()
175 }
176}
177
178impl Buffer for CpuBuffer {
179 fn as_bytes(&self) -> &[u8] {
180 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
182 }
183
184 fn as_bytes_mut(&mut self) -> &mut [u8] {
185 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
187 }
188
189 fn len(&self) -> usize {
190 self.len
191 }
192
193 fn alignment(&self) -> usize {
194 self.align
195 }
196
197 fn device(&self) -> &str {
198 "cpu"
199 }
200}
201
202impl Drop for CpuBuffer {
203 fn drop(&mut self) {
204 unsafe { alloc::dealloc(self.ptr.as_ptr(), self.layout) }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Four zeroed `F32` elements occupy 16 bytes, every one of them zero.
    #[test]
    fn zeros_and_read() {
        let zeroed = CpuBuffer::zeros(4, DType::F32).unwrap();
        assert_eq!(zeroed.len(), 16);
        assert!(!zeroed.as_bytes().iter().any(|&byte| byte != 0));
    }

    /// Values copied in through `from_f32_slice` read back unchanged.
    #[test]
    fn from_f32_roundtrip() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let copied = CpuBuffer::from_f32_slice(&values).unwrap();
        assert_eq!(copied.as_f32_slice(), &values[..]);
    }

    /// A buffer requested with 64-byte alignment starts on a 64-byte boundary.
    #[test]
    fn alignment_guarantee() {
        let aligned = CpuBuffer::with_capacity(32, 64).unwrap();
        let address = aligned.as_ptr() as usize;
        assert_eq!(address % 64, 0);
    }
}