1use std::any::Any;
7use std::fmt::Debug;
8use std::mem::size_of;
9use std::sync::Arc;
10
11use bytes::Bytes;
12use vortex_buffer::Alignment;
13use vortex_buffer::Buffer;
14use vortex_buffer::ByteBuffer;
15use vortex_buffer::ByteBufferMut;
16use vortex_error::VortexResult;
17use vortex_error::vortex_ensure;
18use vortex_error::vortex_err;
19use vortex_session::SessionExt;
20use vortex_session::SessionGuard;
21use vortex_session::SessionVar;
22use vortex_session::VortexSession;
23
24pub trait HostBufferMut: Send + 'static {
26 fn len(&self) -> usize;
28
29 fn is_empty(&self) -> bool {
31 self.len() == 0
32 }
33
34 fn alignment(&self) -> Alignment;
36
37 fn as_mut_slice(&mut self) -> &mut [u8];
39
40 fn freeze(self: Box<Self>) -> ByteBuffer;
42}
43
44pub struct WritableHostBuffer {
46 inner: Box<dyn HostBufferMut>,
47}
48
49impl WritableHostBuffer {
50 pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
52 Self { inner }
53 }
54
55 pub fn len(&self) -> usize {
57 self.inner.len()
58 }
59
60 pub fn is_empty(&self) -> bool {
62 self.len() == 0
63 }
64
65 pub fn alignment(&self) -> Alignment {
67 self.inner.alignment()
68 }
69
70 pub fn as_mut_slice(&mut self) -> &mut [u8] {
72 self.inner.as_mut_slice()
73 }
74
75 pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
77 vortex_ensure!(
78 size_of::<T>() != 0,
79 InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
80 std::any::type_name::<T>()
81 );
82 vortex_ensure!(
83 self.alignment().is_aligned_to(Alignment::of::<T>()),
84 InvalidArgument: "Buffer is not sufficiently aligned for type {}",
85 std::any::type_name::<T>()
86 );
87
88 let bytes = self.as_mut_slice();
89 let byte_len = bytes.len();
90 let ptr = bytes.as_mut_ptr();
91 let type_size = size_of::<T>();
92
93 vortex_ensure!(
94 byte_len.is_multiple_of(type_size),
95 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
96 type_size,
97 std::any::type_name::<T>()
98 );
99
100 Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
103 }
104
105 pub fn freeze(self) -> ByteBuffer {
107 self.inner.freeze()
108 }
109
110 pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
112 vortex_ensure!(
113 size_of::<T>() != 0,
114 InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
115 std::any::type_name::<T>()
116 );
117
118 let buffer = self.freeze();
119 let byte_len = buffer.len();
120 let type_size = size_of::<T>();
121 let type_align = Alignment::of::<T>();
122
123 vortex_ensure!(
124 byte_len.is_multiple_of(type_size),
125 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
126 type_size,
127 std::any::type_name::<T>()
128 );
129 vortex_ensure!(
130 buffer.is_aligned(type_align),
131 InvalidArgument: "Buffer pointer is not aligned to {} for {}",
132 type_align,
133 std::any::type_name::<T>()
134 );
135
136 Ok(Buffer::from_byte_buffer(buffer))
137 }
138}
139
140impl Debug for WritableHostBuffer {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 f.debug_struct("WritableHostBuffer")
143 .field("len", &self.len())
144 .field("alignment", &self.alignment())
145 .finish()
146 }
147}
148
149pub trait HostAllocator: Debug + Send + Sync + 'static {
151 fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
153}
154
155pub type HostAllocatorRef = Arc<dyn HostAllocator>;
157
158pub trait HostAllocatorExt: HostAllocator {
160 fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
162 let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
163 vortex_err!(
164 "Typed host allocation overflow for type {} and len {}",
165 std::any::type_name::<T>(),
166 len
167 )
168 })?;
169 self.allocate(bytes, Alignment::of::<T>())
170 }
171}
172
173impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
174
175#[derive(Clone, Debug)]
177pub struct MemorySession {
178 allocator: HostAllocatorRef,
179}
180
181impl MemorySession {
182 pub fn new(allocator: HostAllocatorRef) -> Self {
184 Self { allocator }
185 }
186
187 pub fn allocator(&self) -> HostAllocatorRef {
189 Arc::clone(&self.allocator)
190 }
191
192 pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
194 self.allocator = allocator;
195 }
196}
197
198impl Default for MemorySession {
199 fn default() -> Self {
200 Self::new(Arc::new(DefaultHostAllocator))
201 }
202}
203
204impl SessionVar for MemorySession {
205 fn as_any(&self) -> &dyn Any {
206 self
207 }
208
209 fn as_any_mut(&mut self) -> &mut dyn Any {
210 self
211 }
212}
213
214pub trait MemorySessionExt: SessionExt {
216 fn memory(&self) -> SessionGuard<'_, MemorySession> {
218 self.get::<MemorySession>()
219 }
220
221 fn allocator(&self) -> HostAllocatorRef {
223 self.memory().allocator()
224 }
225
226 fn with_allocator(self, allocator: HostAllocatorRef) -> VortexSession {
229 let session = self.session();
230 session.get_mut::<MemorySession>().set_allocator(allocator);
231 session
232 }
233}
234
235impl<S: SessionExt> MemorySessionExt for S {}
236
237#[derive(Debug, Default)]
239pub struct DefaultHostAllocator;
240
241impl HostAllocator for DefaultHostAllocator {
242 fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
243 let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
244 unsafe { buffer.set_len(len) };
246 Ok(WritableHostBuffer::new(Box::new(
247 DefaultWritableHostBuffer { buffer, alignment },
248 )))
249 }
250}
251
252#[derive(Debug)]
253struct DefaultWritableHostBuffer {
254 buffer: ByteBufferMut,
255 alignment: Alignment,
256}
257
258#[derive(Debug)]
259struct HostBufferOwner {
260 buffer: ByteBufferMut,
261}
262
263impl AsRef<[u8]> for HostBufferOwner {
264 fn as_ref(&self) -> &[u8] {
265 self.buffer.as_slice()
266 }
267}
268
269impl HostBufferMut for DefaultWritableHostBuffer {
270 fn len(&self) -> usize {
271 self.buffer.len()
272 }
273
274 fn alignment(&self) -> Alignment {
275 self.alignment
276 }
277
278 fn as_mut_slice(&mut self) -> &mut [u8] {
279 self.buffer.as_mut_slice()
280 }
281
282 fn freeze(self: Box<Self>) -> ByteBuffer {
283 let Self { buffer, alignment } = *self;
284 let bytes = Bytes::from_owner(HostBufferOwner { buffer });
285 ByteBuffer::from_bytes_aligned(bytes, alignment)
286 }
287}
288
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that counts `allocate` calls and delegates to [`DefaultHostAllocator`].
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::new(8)).unwrap();
        for (idx, byte) in writable.as_mut_slice().iter_mut().enumerate() {
            *byte = u8::try_from(idx).unwrap();
        }

        let host = writable.freeze();
        assert_eq!(host.len(), 16);
        assert!(host.is_aligned(Alignment::new(8)));
        assert_eq!(host.as_slice(), (0u8..16).collect::<Vec<_>>().as_slice());
    }

    #[test]
    fn memory_session_replaces_allocator() {
        let allocations = Arc::new(AtomicUsize::new(0));
        let allocator = Arc::new(CountingAllocator {
            allocations: Arc::clone(&allocations),
        });
        let mut session = MemorySession::default();
        session.set_allocator(allocator);
        drop(session.allocator().allocate(4, Alignment::none()).unwrap());
        assert_eq!(allocations.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn memory_session_default_allocates() {
        let writable = MemorySession::default()
            .allocator()
            .allocate(3, Alignment::none())
            .unwrap();
        assert_eq!(writable.len(), 3);
    }

    #[test]
    fn typed_allocation_uses_type_alignment() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), 4 * size_of::<u64>());
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    #[test]
    fn allocate_typed_rejects_overflow() {
        assert!(DefaultHostAllocator.allocate_typed::<u64>(usize::MAX).is_err());
    }

    #[test]
    fn typed_mut_slice_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[10, 20, 30, 40]);

        // Compare native-endian bytes rather than reinterpreting the frozen buffer through an
        // unchecked pointer cast, so the test itself needs no `unsafe`.
        let expected: Vec<u8> = [10u64, 20, 30, 40]
            .iter()
            .flat_map(|value| value.to_ne_bytes())
            .collect();
        let frozen = writable.freeze();
        assert_eq!(frozen.as_slice(), expected.as_slice());
    }

    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(7, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u32>().is_err());
    }

    #[test]
    fn typed_mut_slice_rejects_zero_sized_type() {
        let mut writable = DefaultHostAllocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<()>().is_err());
    }

    #[test]
    fn typed_mut_slice_rejects_insufficient_alignment() {
        // The buffer only reports byte alignment, which is not enough for `u64`.
        let mut writable = DefaultHostAllocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u64>().is_err());
    }

    #[test]
    fn freeze_typed_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("not a multiple of"));
    }

    #[test]
    fn freeze_typed_rejects_zero_sized_type() {
        let writable = DefaultHostAllocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.freeze_typed::<()>().is_err());
    }
}