Skip to main content

vortex_array/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped memory allocation for host-side buffers.
5
6use std::any::Any;
7use std::fmt::Debug;
8use std::mem::size_of;
9use std::sync::Arc;
10
11use bytes::Bytes;
12use vortex_buffer::Alignment;
13use vortex_buffer::Buffer;
14use vortex_buffer::ByteBuffer;
15use vortex_buffer::ByteBufferMut;
16use vortex_error::VortexResult;
17use vortex_error::vortex_ensure;
18use vortex_error::vortex_err;
19use vortex_session::Ref;
20use vortex_session::RefMut;
21use vortex_session::SessionExt;
22use vortex_session::SessionVar;
23
/// Mutable host buffer contract used by [`WritableHostBuffer`].
///
/// [`WritableHostBuffer`] stores implementations as `Box<dyn HostBufferMut>` and forwards
/// every call to them. Implementations must be `Send + 'static`.
pub trait HostBufferMut: Send + 'static {
    /// Returns the logical byte length of the buffer.
    fn len(&self) -> usize;

    /// Whether the buffer is empty, i.e. [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the alignment of the buffer.
    ///
    /// [`WritableHostBuffer::as_mut_slice_typed`] consults this value before handing out a
    /// typed view, so it should describe the real alignment of the bytes returned by
    /// [`as_mut_slice`](Self::as_mut_slice).
    fn alignment(&self) -> Alignment;

    /// Returns mutable access to the writable byte range.
    fn as_mut_slice(&mut self) -> &mut [u8];

    /// Freeze the buffer into an immutable [`ByteBuffer`], consuming the boxed buffer.
    fn freeze(self: Box<Self>) -> ByteBuffer;
}
43
/// Exact-size writable host buffer returned by a [`HostAllocator`].
///
/// A thin wrapper around a boxed [`HostBufferMut`] implementation; all operations are
/// delegated to it.
pub struct WritableHostBuffer {
    // Type-erased backing buffer supplied by the allocator.
    inner: Box<dyn HostBufferMut>,
}
48
49impl WritableHostBuffer {
50    /// Create a writable host buffer from an implementation of [`HostBufferMut`].
51    pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
52        Self { inner }
53    }
54
55    /// Returns the logical byte length of the buffer.
56    pub fn len(&self) -> usize {
57        self.inner.len()
58    }
59
60    /// Returns true when the buffer has zero bytes.
61    pub fn is_empty(&self) -> bool {
62        self.len() == 0
63    }
64
65    /// Returns the alignment of the buffer.
66    pub fn alignment(&self) -> Alignment {
67        self.inner.alignment()
68    }
69
70    /// Returns mutable access to the writable byte range.
71    pub fn as_mut_slice(&mut self) -> &mut [u8] {
72        self.inner.as_mut_slice()
73    }
74
75    /// Returns mutable access to the buffer as a typed slice.
76    pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
77        vortex_ensure!(
78            size_of::<T>() != 0,
79            InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
80            std::any::type_name::<T>()
81        );
82        vortex_ensure!(
83            self.alignment().is_aligned_to(Alignment::of::<T>()),
84            InvalidArgument: "Buffer is not sufficiently aligned for type {}",
85            std::any::type_name::<T>()
86        );
87
88        let bytes = self.as_mut_slice();
89        let byte_len = bytes.len();
90        let ptr = bytes.as_mut_ptr();
91        let type_size = size_of::<T>();
92
93        vortex_ensure!(
94            byte_len.is_multiple_of(type_size),
95            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
96            type_size,
97            std::any::type_name::<T>()
98        );
99
100        // SAFETY: We checked size divisibility and pointer alignment for `T`,
101        // and we have exclusive mutable access to the underlying bytes.
102        Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
103    }
104
105    /// Freeze the writable buffer into an immutable [`ByteBuffer`].
106    pub fn freeze(self) -> ByteBuffer {
107        self.inner.freeze()
108    }
109
110    /// Freeze the writable buffer into a typed immutable [`Buffer<T>`].
111    pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
112        vortex_ensure!(
113            size_of::<T>() != 0,
114            InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
115            std::any::type_name::<T>()
116        );
117
118        let buffer = self.freeze();
119        let byte_len = buffer.len();
120        let type_size = size_of::<T>();
121        let type_align = Alignment::of::<T>();
122
123        vortex_ensure!(
124            byte_len.is_multiple_of(type_size),
125            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
126            type_size,
127            std::any::type_name::<T>()
128        );
129        vortex_ensure!(
130            buffer.is_aligned(type_align),
131            InvalidArgument: "Buffer pointer is not aligned to {} for {}",
132            type_align,
133            std::any::type_name::<T>()
134        );
135
136        Ok(Buffer::from_byte_buffer(buffer))
137    }
138}
139
140impl Debug for WritableHostBuffer {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.debug_struct("WritableHostBuffer")
143            .field("len", &self.len())
144            .field("alignment", &self.alignment())
145            .finish()
146    }
147}
148
/// Allocator for exact-size writable host buffers.
///
/// Implementations must be `Debug + Send + Sync + 'static`; they are shared behind
/// `Arc<dyn HostAllocator>` (see [`HostAllocatorRef`]).
pub trait HostAllocator: Debug + Send + Sync + 'static {
    /// Allocate a writable host buffer with the requested byte length and alignment.
    ///
    /// `len` is a size in bytes, not an element count.
    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
}
154
/// Shared allocator reference used throughout session-scoped memory APIs.
///
/// Cloning is a reference-count bump on the `Arc`; the allocator itself is not copied.
pub type HostAllocatorRef = Arc<dyn HostAllocator>;
157
158/// Extension methods for [`HostAllocator`]s.
159pub trait HostAllocatorExt: HostAllocator {
160    /// Allocate host memory for `len` elements of `T` using `Alignment::of::<T>()`.
161    fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
162        let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
163            vortex_err!(
164                "Typed host allocation overflow for type {} and len {}",
165                std::any::type_name::<T>(),
166                len
167            )
168        })?;
169        self.allocate(bytes, Alignment::of::<T>())
170    }
171}
172
// Blanket impl: every `HostAllocator`, including unsized ones such as `dyn HostAllocator`
// (hence `?Sized`), gets the `HostAllocatorExt` methods.
impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
174
/// Session-scoped memory configuration for Vortex arrays.
///
/// Currently holds only the host allocator used for the session.
#[derive(Debug)]
pub struct MemorySession {
    // Allocator handed out by `allocator()`; replaceable through `set_allocator()`.
    allocator: HostAllocatorRef,
}
180
181impl MemorySession {
182    /// Creates a new session memory configuration using the provided allocator.
183    pub fn new(allocator: HostAllocatorRef) -> Self {
184        Self { allocator }
185    }
186
187    /// Returns the configured allocator.
188    pub fn allocator(&self) -> HostAllocatorRef {
189        Arc::clone(&self.allocator)
190    }
191
192    /// Updates the configured allocator.
193    pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
194        self.allocator = allocator;
195    }
196}
197
198impl Default for MemorySession {
199    fn default() -> Self {
200        Self::new(Arc::new(DefaultHostAllocator))
201    }
202}
203
// Exposes `MemorySession` as `Any`; `MemorySessionExt` looks it up by type via
// `get::<MemorySession>()` / `get_mut::<MemorySession>()`.
impl SessionVar for MemorySession {
    /// Returns `self` as a shared [`Any`] reference.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns `self` as a mutable [`Any`] reference.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}
213
214/// Extension trait for accessing session-scoped memory configuration.
215pub trait MemorySessionExt: SessionExt {
216    /// Returns the memory session for this execution/session context.
217    fn memory(&self) -> Ref<'_, MemorySession> {
218        self.get::<MemorySession>()
219    }
220
221    /// Returns the configured host allocator for this execution/session context.
222    fn allocator(&self) -> HostAllocatorRef {
223        self.memory().allocator()
224    }
225
226    /// Returns mutable access to the memory session.
227    fn memory_mut(&self) -> RefMut<'_, MemorySession> {
228        self.get_mut::<MemorySession>()
229    }
230}
231
// Blanket impl: every `SessionExt` type gets the `MemorySessionExt` accessors.
impl<S: SessionExt> MemorySessionExt for S {}
233
/// Default host allocator.
///
/// A stateless unit type that allocates through [`ByteBufferMut`]; it is the allocator
/// used by `MemorySession::default()`.
#[derive(Debug, Default)]
pub struct DefaultHostAllocator;
237
238impl HostAllocator for DefaultHostAllocator {
239    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
240        let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
241        // SAFETY: We fully initialize this slice before freezing it.
242        unsafe { buffer.set_len(len) };
243        Ok(WritableHostBuffer::new(Box::new(
244            DefaultWritableHostBuffer { buffer, alignment },
245        )))
246    }
247}
248
/// Writable buffer produced by [`DefaultHostAllocator`].
#[derive(Debug)]
struct DefaultWritableHostBuffer {
    // Backing storage whose length was set to the requested byte count.
    buffer: ByteBufferMut,
    // Alignment that was requested from the allocator; reported by `alignment()` and
    // re-applied when freezing.
    alignment: Alignment,
}
254
/// Owner wrapper handed to `Bytes::from_owner` when a [`DefaultWritableHostBuffer`] is
/// frozen, keeping the underlying [`ByteBufferMut`] alive for the frozen bytes.
#[derive(Debug)]
struct HostBufferOwner {
    // The buffer whose bytes are exposed through `AsRef<[u8]>`.
    buffer: ByteBufferMut,
}
259
// `Bytes::from_owner` (used in `DefaultWritableHostBuffer::freeze`) reads the byte range
// from the owner through this impl.
impl AsRef<[u8]> for HostBufferOwner {
    /// Returns the owned buffer's bytes.
    fn as_ref(&self) -> &[u8] {
        self.buffer.as_slice()
    }
}
265
266impl HostBufferMut for DefaultWritableHostBuffer {
267    fn len(&self) -> usize {
268        self.buffer.len()
269    }
270
271    fn alignment(&self) -> Alignment {
272        self.alignment
273    }
274
275    fn as_mut_slice(&mut self) -> &mut [u8] {
276        self.buffer.as_mut_slice()
277    }
278
279    fn freeze(self: Box<Self>) -> ByteBuffer {
280        let Self { buffer, alignment } = *self;
281        let bytes = Bytes::from_owner(HostBufferOwner { buffer });
282        ByteBuffer::from_bytes_aligned(bytes, alignment)
283    }
284}
285
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that counts `allocate` calls and delegates to [`DefaultHostAllocator`].
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::new(8)).unwrap();
        for (idx, byte) in writable.as_mut_slice().iter_mut().enumerate() {
            *byte = u8::try_from(idx).unwrap();
        }

        let host = writable.freeze();
        assert_eq!(host.len(), 16);
        assert!(host.is_aligned(Alignment::new(8)));
        assert_eq!(host.as_slice(), (0u8..16).collect::<Vec<_>>().as_slice());
    }

    #[test]
    fn memory_session_replaces_allocator() {
        let allocations = Arc::new(AtomicUsize::new(0));
        let allocator = Arc::new(CountingAllocator {
            allocations: Arc::clone(&allocations),
        });
        let mut session = MemorySession::default();
        session.set_allocator(allocator);
        drop(session.allocator().allocate(4, Alignment::none()).unwrap());
        assert_eq!(allocations.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn typed_allocation_uses_type_alignment() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), 4 * size_of::<u64>());
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    #[test]
    fn typed_allocation_rejects_size_overflow() {
        let allocator = DefaultHostAllocator;
        // `usize::MAX * size_of::<u64>()` cannot be represented, so this must fail before
        // any allocation is attempted.
        assert!(allocator.allocate_typed::<u64>(usize::MAX).is_err());
    }

    #[test]
    fn typed_mut_slice_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[10, 20, 30, 40]);

        let frozen = writable.freeze();
        // Decode the frozen bytes with safe code; native endianness matches how the
        // values were written through the typed slice.
        let values: Vec<u64> = frozen
            .as_slice()
            .chunks_exact(size_of::<u64>())
            .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap()))
            .collect();
        assert_eq!(values, [10, 20, 30, 40]);
    }

    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        // Request an alignment that satisfies `u32` so that the length check, rather than
        // the alignment check, is what rejects the 7-byte buffer.
        let mut writable = allocator.allocate(7, Alignment::of::<u32>()).unwrap();
        let err = writable.as_mut_slice_typed::<u32>().unwrap_err();
        assert!(err.to_string().contains("not a multiple of"));
    }

    #[test]
    fn typed_mut_slice_rejects_insufficient_alignment() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(8, Alignment::none()).unwrap();
        let err = writable.as_mut_slice_typed::<u64>().unwrap_err();
        assert!(err.to_string().contains("not sufficiently aligned"));
    }

    #[test]
    fn typed_mut_slice_rejects_zero_sized_type() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<()>().is_err());
    }

    #[test]
    fn freeze_typed_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        assert!(err.to_string().contains("not a multiple of"));
    }

    #[test]
    fn freeze_typed_rejects_zero_sized_type() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(8, Alignment::none()).unwrap();
        assert!(writable.freeze_typed::<()>().is_err());
    }
}