1use std::fmt::Debug;
7use std::mem::size_of;
8use std::sync::Arc;
9
10use bytes::Bytes;
11use vortex_buffer::Alignment;
12use vortex_buffer::Buffer;
13use vortex_buffer::ByteBuffer;
14use vortex_buffer::ByteBufferMut;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_err;
18use vortex_session::Ref;
19use vortex_session::RefMut;
20use vortex_session::SessionExt;
21
/// A writable region of host (CPU) memory that can be filled and then frozen
/// into an immutable [`ByteBuffer`].
///
/// Implementations are type-erased behind `Box<dyn HostBufferMut>` by
/// [`WritableHostBuffer`], so allocators can supply custom backing storage.
pub trait HostBufferMut: Send + 'static {
    /// The number of writable bytes in this buffer.
    fn len(&self) -> usize;

    /// Returns `true` when the buffer holds zero bytes.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The alignment this buffer claims for its starting address.
    fn alignment(&self) -> Alignment;

    /// Borrows the entire buffer contents as a mutable byte slice.
    fn as_mut_slice(&mut self) -> &mut [u8];

    /// Consumes the buffer and produces an immutable [`ByteBuffer`] over the
    /// same bytes. Takes `Box<Self>` so trait objects can be consumed by value.
    fn freeze(self: Box<Self>) -> ByteBuffer;
}
41
/// A writable host-memory buffer handed out by a [`HostAllocator`].
///
/// Wraps a type-erased [`HostBufferMut`] so callers can write bytes — raw or
/// through a typed view — and then freeze the result into an immutable
/// [`ByteBuffer`] (or typed [`Buffer<T>`]).
pub struct WritableHostBuffer {
    // Type-erased backing buffer supplied by the allocator implementation.
    inner: Box<dyn HostBufferMut>,
}
46
47impl WritableHostBuffer {
48 pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
50 Self { inner }
51 }
52
53 pub fn len(&self) -> usize {
55 self.inner.len()
56 }
57
58 pub fn is_empty(&self) -> bool {
60 self.len() == 0
61 }
62
63 pub fn alignment(&self) -> Alignment {
65 self.inner.alignment()
66 }
67
68 pub fn as_mut_slice(&mut self) -> &mut [u8] {
70 self.inner.as_mut_slice()
71 }
72
73 pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
75 vortex_ensure!(
76 size_of::<T>() != 0,
77 InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
78 std::any::type_name::<T>()
79 );
80 vortex_ensure!(
81 self.alignment().is_aligned_to(Alignment::of::<T>()),
82 InvalidArgument: "Buffer is not sufficiently aligned for type {}",
83 std::any::type_name::<T>()
84 );
85
86 let bytes = self.as_mut_slice();
87 let byte_len = bytes.len();
88 let ptr = bytes.as_mut_ptr();
89 let type_size = size_of::<T>();
90
91 vortex_ensure!(
92 byte_len.is_multiple_of(type_size),
93 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
94 type_size,
95 std::any::type_name::<T>()
96 );
97
98 Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
101 }
102
103 pub fn freeze(self) -> ByteBuffer {
105 self.inner.freeze()
106 }
107
108 pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
110 vortex_ensure!(
111 size_of::<T>() != 0,
112 InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
113 std::any::type_name::<T>()
114 );
115
116 let buffer = self.freeze();
117 let byte_len = buffer.len();
118 let type_size = size_of::<T>();
119 let type_align = Alignment::of::<T>();
120
121 vortex_ensure!(
122 byte_len.is_multiple_of(type_size),
123 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
124 type_size,
125 std::any::type_name::<T>()
126 );
127 vortex_ensure!(
128 buffer.is_aligned(type_align),
129 InvalidArgument: "Buffer pointer is not aligned to {} for {}",
130 type_align,
131 std::any::type_name::<T>()
132 );
133
134 Ok(Buffer::from_byte_buffer(buffer))
135 }
136}
137
138impl Debug for WritableHostBuffer {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("WritableHostBuffer")
141 .field("len", &self.len())
142 .field("alignment", &self.alignment())
143 .finish()
144 }
145}
146
/// Allocates writable buffers in host (CPU) memory.
pub trait HostAllocator: Debug + Send + Sync + 'static {
    /// Allocates `len` bytes aligned to `alignment`.
    ///
    /// NOTE(review): the default implementation in this file hands back
    /// uninitialized bytes via `set_len` — callers must write before reading.
    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
}

/// Shared, thread-safe handle to a [`HostAllocator`].
pub type HostAllocatorRef = Arc<dyn HostAllocator>;
155
156pub trait HostAllocatorExt: HostAllocator {
158 fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
160 let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
161 vortex_err!(
162 "Typed host allocation overflow for type {} and len {}",
163 std::any::type_name::<T>(),
164 len
165 )
166 })?;
167 self.allocate(bytes, Alignment::of::<T>())
168 }
169}
170
// Blanket impl: every `HostAllocator` (including `dyn HostAllocator`) gets the typed helpers.
impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
172
/// Session state that owns the host allocator used for buffer allocation.
#[derive(Debug)]
pub struct MemorySession {
    // Current allocator; replaceable at runtime via `set_allocator`.
    allocator: HostAllocatorRef,
}
178
179impl MemorySession {
180 pub fn new(allocator: HostAllocatorRef) -> Self {
182 Self { allocator }
183 }
184
185 pub fn allocator(&self) -> HostAllocatorRef {
187 Arc::clone(&self.allocator)
188 }
189
190 pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
192 self.allocator = allocator;
193 }
194}
195
196impl Default for MemorySession {
197 fn default() -> Self {
198 Self::new(Arc::new(DefaultHostAllocator))
199 }
200}
201
202pub trait MemorySessionExt: SessionExt {
204 fn memory(&self) -> Ref<'_, MemorySession> {
206 self.get::<MemorySession>()
207 }
208
209 fn allocator(&self) -> HostAllocatorRef {
211 self.memory().allocator()
212 }
213
214 fn memory_mut(&self) -> RefMut<'_, MemorySession> {
216 self.get_mut::<MemorySession>()
217 }
218}
219
// Blanket impl: any session type gains the memory accessors for free.
impl<S: SessionExt> MemorySessionExt for S {}
221
/// Host allocator backed by ordinary aligned heap allocations ([`ByteBufferMut`]).
#[derive(Debug, Default)]
pub struct DefaultHostAllocator;
225
226impl HostAllocator for DefaultHostAllocator {
227 fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
228 let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
229 unsafe { buffer.set_len(len) };
231 Ok(WritableHostBuffer::new(Box::new(
232 DefaultWritableHostBuffer { buffer, alignment },
233 )))
234 }
235}
236
/// [`HostBufferMut`] implementation backing [`DefaultHostAllocator`].
#[derive(Debug)]
struct DefaultWritableHostBuffer {
    // Mutable storage; frozen into an immutable `ByteBuffer` on `freeze`.
    buffer: ByteBufferMut,
    // Alignment the allocation was created with, reported back to callers.
    alignment: Alignment,
}
242
/// Keeps the original [`ByteBufferMut`] alive while its bytes are viewed
/// through a [`Bytes`] handle (see `Bytes::from_owner` in `freeze`).
#[derive(Debug)]
struct HostBufferOwner {
    buffer: ByteBufferMut,
}

impl AsRef<[u8]> for HostBufferOwner {
    // `Bytes::from_owner` requires `AsRef<[u8]>` to locate the owned bytes.
    fn as_ref(&self) -> &[u8] {
        self.buffer.as_slice()
    }
}
253
254impl HostBufferMut for DefaultWritableHostBuffer {
255 fn len(&self) -> usize {
256 self.buffer.len()
257 }
258
259 fn alignment(&self) -> Alignment {
260 self.alignment
261 }
262
263 fn as_mut_slice(&mut self) -> &mut [u8] {
264 self.buffer.as_mut_slice()
265 }
266
267 fn freeze(self: Box<Self>) -> ByteBuffer {
268 let Self { buffer, alignment } = *self;
269 let bytes = Bytes::from_owner(HostBufferOwner { buffer });
270 ByteBuffer::from_bytes_aligned(bytes, alignment)
271 }
272}
273
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that counts calls and delegates to [`DefaultHostAllocator`];
    /// used to verify that a session routes allocations through its allocator.
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            // Relaxed is sufficient: the count is only read after allocation returns.
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    // Bytes written through the mutable view must survive `freeze`, and the
    // frozen buffer must retain the requested length and alignment.
    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::new(8)).unwrap();
        for (idx, byte) in writable.as_mut_slice().iter_mut().enumerate() {
            *byte = u8::try_from(idx).unwrap();
        }

        let host = writable.freeze();
        assert_eq!(host.len(), 16);
        assert!(host.is_aligned(Alignment::new(8)));
        assert_eq!(host.as_slice(), (0u8..16).collect::<Vec<_>>().as_slice());
    }

    // After `set_allocator`, allocations must go through the replacement.
    #[test]
    fn memory_session_replaces_allocator() {
        let allocations = Arc::new(AtomicUsize::new(0));
        let allocator = Arc::new(CountingAllocator {
            allocations: Arc::clone(&allocations),
        });
        let mut session = MemorySession::default();
        session.set_allocator(allocator);
        drop(session.allocator().allocate(4, Alignment::none()).unwrap());
        assert_eq!(allocations.load(Ordering::Relaxed), 1);
    }

    // `allocate_typed::<T>(n)` must size the buffer as n * size_of::<T>()
    // and align it for T.
    #[test]
    fn typed_allocation_uses_type_alignment() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), 4 * size_of::<u64>());
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    // Values written through the typed mutable view must be readable from the
    // frozen byte buffer when reinterpreted as the same type.
    #[test]
    fn typed_mut_slice_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[10, 20, 30, 40]);

        let frozen = writable.freeze();
        // SAFETY (test-only): the buffer was allocated via allocate_typed::<u64>,
        // so it is u64-aligned and its length is a multiple of size_of::<u64>().
        let values = unsafe {
            std::slice::from_raw_parts(
                frozen.as_slice().as_ptr().cast::<u64>(),
                frozen.len() / size_of::<u64>(),
            )
        };
        assert_eq!(values, [10, 20, 30, 40]);
    }

    // A 7-byte buffer cannot be viewed as u32s (7 % 4 != 0).
    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(7, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u32>().is_err());
    }

    // `freeze_typed` must produce a typed buffer exposing the written values.
    #[test]
    fn freeze_typed_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    // `freeze_typed` must reject a byte length that is not a multiple of the
    // element size, with the documented error message.
    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("not a multiple of"));
    }
}