1use std::any::Any;
7use std::fmt::Debug;
8use std::mem::size_of;
9use std::sync::Arc;
10
11use bytes::Bytes;
12use vortex_buffer::Alignment;
13use vortex_buffer::Buffer;
14use vortex_buffer::ByteBuffer;
15use vortex_buffer::ByteBufferMut;
16use vortex_error::VortexResult;
17use vortex_error::vortex_ensure;
18use vortex_error::vortex_err;
19use vortex_session::Ref;
20use vortex_session::RefMut;
21use vortex_session::SessionExt;
22use vortex_session::SessionVar;
23
24pub trait HostBufferMut: Send + 'static {
26 fn len(&self) -> usize;
28
29 fn is_empty(&self) -> bool {
31 self.len() == 0
32 }
33
34 fn alignment(&self) -> Alignment;
36
37 fn as_mut_slice(&mut self) -> &mut [u8];
39
40 fn freeze(self: Box<Self>) -> ByteBuffer;
42}
43
44pub struct WritableHostBuffer {
46 inner: Box<dyn HostBufferMut>,
47}
48
49impl WritableHostBuffer {
50 pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
52 Self { inner }
53 }
54
55 pub fn len(&self) -> usize {
57 self.inner.len()
58 }
59
60 pub fn is_empty(&self) -> bool {
62 self.len() == 0
63 }
64
65 pub fn alignment(&self) -> Alignment {
67 self.inner.alignment()
68 }
69
70 pub fn as_mut_slice(&mut self) -> &mut [u8] {
72 self.inner.as_mut_slice()
73 }
74
75 pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
77 vortex_ensure!(
78 size_of::<T>() != 0,
79 InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
80 std::any::type_name::<T>()
81 );
82 vortex_ensure!(
83 self.alignment().is_aligned_to(Alignment::of::<T>()),
84 InvalidArgument: "Buffer is not sufficiently aligned for type {}",
85 std::any::type_name::<T>()
86 );
87
88 let bytes = self.as_mut_slice();
89 let byte_len = bytes.len();
90 let ptr = bytes.as_mut_ptr();
91 let type_size = size_of::<T>();
92
93 vortex_ensure!(
94 byte_len.is_multiple_of(type_size),
95 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
96 type_size,
97 std::any::type_name::<T>()
98 );
99
100 Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
103 }
104
105 pub fn freeze(self) -> ByteBuffer {
107 self.inner.freeze()
108 }
109
110 pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
112 vortex_ensure!(
113 size_of::<T>() != 0,
114 InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
115 std::any::type_name::<T>()
116 );
117
118 let buffer = self.freeze();
119 let byte_len = buffer.len();
120 let type_size = size_of::<T>();
121 let type_align = Alignment::of::<T>();
122
123 vortex_ensure!(
124 byte_len.is_multiple_of(type_size),
125 InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
126 type_size,
127 std::any::type_name::<T>()
128 );
129 vortex_ensure!(
130 buffer.is_aligned(type_align),
131 InvalidArgument: "Buffer pointer is not aligned to {} for {}",
132 type_align,
133 std::any::type_name::<T>()
134 );
135
136 Ok(Buffer::from_byte_buffer(buffer))
137 }
138}
139
140impl Debug for WritableHostBuffer {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 f.debug_struct("WritableHostBuffer")
143 .field("len", &self.len())
144 .field("alignment", &self.alignment())
145 .finish()
146 }
147}
148
149pub trait HostAllocator: Debug + Send + Sync + 'static {
151 fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
153}
154
155pub type HostAllocatorRef = Arc<dyn HostAllocator>;
157
158pub trait HostAllocatorExt: HostAllocator {
160 fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
162 let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
163 vortex_err!(
164 "Typed host allocation overflow for type {} and len {}",
165 std::any::type_name::<T>(),
166 len
167 )
168 })?;
169 self.allocate(bytes, Alignment::of::<T>())
170 }
171}
172
173impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
174
175#[derive(Debug)]
177pub struct MemorySession {
178 allocator: HostAllocatorRef,
179}
180
181impl MemorySession {
182 pub fn new(allocator: HostAllocatorRef) -> Self {
184 Self { allocator }
185 }
186
187 pub fn allocator(&self) -> HostAllocatorRef {
189 Arc::clone(&self.allocator)
190 }
191
192 pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
194 self.allocator = allocator;
195 }
196}
197
198impl Default for MemorySession {
199 fn default() -> Self {
200 Self::new(Arc::new(DefaultHostAllocator))
201 }
202}
203
204impl SessionVar for MemorySession {
205 fn as_any(&self) -> &dyn Any {
206 self
207 }
208
209 fn as_any_mut(&mut self) -> &mut dyn Any {
210 self
211 }
212}
213
214pub trait MemorySessionExt: SessionExt {
216 fn memory(&self) -> Ref<'_, MemorySession> {
218 self.get::<MemorySession>()
219 }
220
221 fn allocator(&self) -> HostAllocatorRef {
223 self.memory().allocator()
224 }
225
226 fn memory_mut(&self) -> RefMut<'_, MemorySession> {
228 self.get_mut::<MemorySession>()
229 }
230}
231
232impl<S: SessionExt> MemorySessionExt for S {}
233
/// The allocator installed by [`MemorySession::default`], backed by [`ByteBufferMut`].
#[derive(Debug, Default)]
pub struct DefaultHostAllocator;
237
238impl HostAllocator for DefaultHostAllocator {
239 fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
240 let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
241 unsafe { buffer.set_len(len) };
243 Ok(WritableHostBuffer::new(Box::new(
244 DefaultWritableHostBuffer { buffer, alignment },
245 )))
246 }
247}
248
249#[derive(Debug)]
250struct DefaultWritableHostBuffer {
251 buffer: ByteBufferMut,
252 alignment: Alignment,
253}
254
255#[derive(Debug)]
256struct HostBufferOwner {
257 buffer: ByteBufferMut,
258}
259
260impl AsRef<[u8]> for HostBufferOwner {
261 fn as_ref(&self) -> &[u8] {
262 self.buffer.as_slice()
263 }
264}
265
266impl HostBufferMut for DefaultWritableHostBuffer {
267 fn len(&self) -> usize {
268 self.buffer.len()
269 }
270
271 fn alignment(&self) -> Alignment {
272 self.alignment
273 }
274
275 fn as_mut_slice(&mut self) -> &mut [u8] {
276 self.buffer.as_mut_slice()
277 }
278
279 fn freeze(self: Box<Self>) -> ByteBuffer {
280 let Self { buffer, alignment } = *self;
281 let bytes = Bytes::from_owner(HostBufferOwner { buffer });
282 ByteBuffer::from_bytes_aligned(bytes, alignment)
283 }
284}
285
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that counts calls and delegates the actual allocation to
    /// [`DefaultHostAllocator`].
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::new(8)).unwrap();
        for (idx, byte) in writable.as_mut_slice().iter_mut().enumerate() {
            *byte = u8::try_from(idx).unwrap();
        }

        let host = writable.freeze();
        assert_eq!(host.len(), 16);
        assert!(host.is_aligned(Alignment::new(8)));
        assert_eq!(host.as_slice(), (0u8..16).collect::<Vec<_>>().as_slice());
    }

    #[test]
    fn memory_session_replaces_allocator() {
        let allocations = Arc::new(AtomicUsize::new(0));
        let allocator = Arc::new(CountingAllocator {
            allocations: Arc::clone(&allocations),
        });
        let mut session = MemorySession::default();
        session.set_allocator(allocator);
        drop(session.allocator().allocate(4, Alignment::none()).unwrap());
        assert_eq!(allocations.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn typed_allocation_uses_type_alignment() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), 4 * size_of::<u64>());
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    #[test]
    fn typed_mut_slice_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[10, 20, 30, 40]);

        let frozen = writable.freeze();
        assert!(frozen.is_aligned(Alignment::of::<u64>()));
        // SAFETY: `frozen` is aligned for `u64` (asserted above), its length is a multiple of
        // `size_of::<u64>()`, and every byte was written through the typed slice.
        let values = unsafe {
            std::slice::from_raw_parts(
                frozen.as_slice().as_ptr().cast::<u64>(),
                frozen.len() / size_of::<u64>(),
            )
        };
        assert_eq!(values, [10, 20, 30, 40]);
    }

    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(7, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u32>().is_err());
    }

    #[test]
    fn typed_mut_slice_rejects_insufficient_alignment() {
        // The reported alignment (1) is weaker than `u64` requires, so the cast is refused.
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u64>().is_err());
    }

    #[test]
    fn typed_mut_slice_rejects_zero_sized_type() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<()>().is_err());
    }

    #[test]
    fn typed_mut_slice_of_empty_buffer_is_empty() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(0).unwrap();
        assert!(writable.as_mut_slice_typed::<u64>().unwrap().is_empty());
    }

    #[test]
    fn freeze_typed_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("not a multiple of"));
    }

    #[test]
    fn freeze_typed_rejects_zero_sized_type() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.freeze_typed::<()>().is_err());
    }
}