Skip to main content

vortex_array/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped memory allocation for host-side buffers.
5
6use std::fmt::Debug;
7use std::mem::size_of;
8use std::sync::Arc;
9
10use bytes::Bytes;
11use vortex_buffer::Alignment;
12use vortex_buffer::Buffer;
13use vortex_buffer::ByteBuffer;
14use vortex_buffer::ByteBufferMut;
15use vortex_error::VortexResult;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_err;
18use vortex_session::Ref;
19use vortex_session::RefMut;
20use vortex_session::SessionExt;
21
22/// Mutable host buffer contract used by [`WritableHostBuffer`].
23pub trait HostBufferMut: Send + 'static {
24    /// Returns the logical byte length of the buffer.
25    fn len(&self) -> usize;
26
27    /// Whether the buffer is empty.
28    fn is_empty(&self) -> bool {
29        self.len() == 0
30    }
31
32    /// Returns the alignment of the buffer.
33    fn alignment(&self) -> Alignment;
34
35    /// Returns mutable access to the writable byte range.
36    fn as_mut_slice(&mut self) -> &mut [u8];
37
38    /// Freeze the buffer into an immutable [`ByteBuffer`].
39    fn freeze(self: Box<Self>) -> ByteBuffer;
40}
41
42/// Exact-size writable host buffer returned by a [`HostAllocator`].
43pub struct WritableHostBuffer {
44    inner: Box<dyn HostBufferMut>,
45}
46
47impl WritableHostBuffer {
48    /// Create a writable host buffer from an implementation of [`HostBufferMut`].
49    pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
50        Self { inner }
51    }
52
53    /// Returns the logical byte length of the buffer.
54    pub fn len(&self) -> usize {
55        self.inner.len()
56    }
57
58    /// Returns true when the buffer has zero bytes.
59    pub fn is_empty(&self) -> bool {
60        self.len() == 0
61    }
62
63    /// Returns the alignment of the buffer.
64    pub fn alignment(&self) -> Alignment {
65        self.inner.alignment()
66    }
67
68    /// Returns mutable access to the writable byte range.
69    pub fn as_mut_slice(&mut self) -> &mut [u8] {
70        self.inner.as_mut_slice()
71    }
72
73    /// Returns mutable access to the buffer as a typed slice.
74    pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
75        vortex_ensure!(
76            size_of::<T>() != 0,
77            InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
78            std::any::type_name::<T>()
79        );
80        vortex_ensure!(
81            self.alignment().is_aligned_to(Alignment::of::<T>()),
82            InvalidArgument: "Buffer is not sufficiently aligned for type {}",
83            std::any::type_name::<T>()
84        );
85
86        let bytes = self.as_mut_slice();
87        let byte_len = bytes.len();
88        let ptr = bytes.as_mut_ptr();
89        let type_size = size_of::<T>();
90
91        vortex_ensure!(
92            byte_len.is_multiple_of(type_size),
93            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
94            type_size,
95            std::any::type_name::<T>()
96        );
97
98        // SAFETY: We checked size divisibility and pointer alignment for `T`,
99        // and we have exclusive mutable access to the underlying bytes.
100        Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
101    }
102
103    /// Freeze the writable buffer into an immutable [`ByteBuffer`].
104    pub fn freeze(self) -> ByteBuffer {
105        self.inner.freeze()
106    }
107
108    /// Freeze the writable buffer into a typed immutable [`Buffer<T>`].
109    pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
110        vortex_ensure!(
111            size_of::<T>() != 0,
112            InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
113            std::any::type_name::<T>()
114        );
115
116        let buffer = self.freeze();
117        let byte_len = buffer.len();
118        let type_size = size_of::<T>();
119        let type_align = Alignment::of::<T>();
120
121        vortex_ensure!(
122            byte_len.is_multiple_of(type_size),
123            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
124            type_size,
125            std::any::type_name::<T>()
126        );
127        vortex_ensure!(
128            buffer.is_aligned(type_align),
129            InvalidArgument: "Buffer pointer is not aligned to {} for {}",
130            type_align,
131            std::any::type_name::<T>()
132        );
133
134        Ok(Buffer::from_byte_buffer(buffer))
135    }
136}
137
138impl Debug for WritableHostBuffer {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("WritableHostBuffer")
141            .field("len", &self.len())
142            .field("alignment", &self.alignment())
143            .finish()
144    }
145}
146
147/// Allocator for exact-size writable host buffers.
148pub trait HostAllocator: Debug + Send + Sync + 'static {
149    /// Allocate a writable host buffer with the requested byte length and alignment.
150    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
151}
152
153/// Shared allocator reference used throughout session-scoped memory APIs.
154pub type HostAllocatorRef = Arc<dyn HostAllocator>;
155
156/// Extension methods for [`HostAllocator`]s.
157pub trait HostAllocatorExt: HostAllocator {
158    /// Allocate host memory for `len` elements of `T` using `Alignment::of::<T>()`.
159    fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
160        let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
161            vortex_err!(
162                "Typed host allocation overflow for type {} and len {}",
163                std::any::type_name::<T>(),
164                len
165            )
166        })?;
167        self.allocate(bytes, Alignment::of::<T>())
168    }
169}
170
171impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
172
173/// Session-scoped memory configuration for Vortex arrays.
174#[derive(Debug)]
175pub struct MemorySession {
176    allocator: HostAllocatorRef,
177}
178
179impl MemorySession {
180    /// Creates a new session memory configuration using the provided allocator.
181    pub fn new(allocator: HostAllocatorRef) -> Self {
182        Self { allocator }
183    }
184
185    /// Returns the configured allocator.
186    pub fn allocator(&self) -> HostAllocatorRef {
187        Arc::clone(&self.allocator)
188    }
189
190    /// Updates the configured allocator.
191    pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
192        self.allocator = allocator;
193    }
194}
195
196impl Default for MemorySession {
197    fn default() -> Self {
198        Self::new(Arc::new(DefaultHostAllocator))
199    }
200}
201
202/// Extension trait for accessing session-scoped memory configuration.
203pub trait MemorySessionExt: SessionExt {
204    /// Returns the memory session for this execution/session context.
205    fn memory(&self) -> Ref<'_, MemorySession> {
206        self.get::<MemorySession>()
207    }
208
209    /// Returns the configured host allocator for this execution/session context.
210    fn allocator(&self) -> HostAllocatorRef {
211        self.memory().allocator()
212    }
213
214    /// Returns mutable access to the memory session.
215    fn memory_mut(&self) -> RefMut<'_, MemorySession> {
216        self.get_mut::<MemorySession>()
217    }
218}
219
220impl<S: SessionExt> MemorySessionExt for S {}
221
/// Host allocator installed by [`MemorySession::default`].
///
/// Stateless; buffers are backed by an aligned `ByteBufferMut`.
#[derive(Debug, Default)]
pub struct DefaultHostAllocator;
225
226impl HostAllocator for DefaultHostAllocator {
227    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
228        let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
229        // SAFETY: We fully initialize this slice before freezing it.
230        unsafe { buffer.set_len(len) };
231        Ok(WritableHostBuffer::new(Box::new(
232            DefaultWritableHostBuffer { buffer, alignment },
233        )))
234    }
235}
236
237#[derive(Debug)]
238struct DefaultWritableHostBuffer {
239    buffer: ByteBufferMut,
240    alignment: Alignment,
241}
242
243#[derive(Debug)]
244struct HostBufferOwner {
245    buffer: ByteBufferMut,
246}
247
248impl AsRef<[u8]> for HostBufferOwner {
249    fn as_ref(&self) -> &[u8] {
250        self.buffer.as_slice()
251    }
252}
253
254impl HostBufferMut for DefaultWritableHostBuffer {
255    fn len(&self) -> usize {
256        self.buffer.len()
257    }
258
259    fn alignment(&self) -> Alignment {
260        self.alignment
261    }
262
263    fn as_mut_slice(&mut self) -> &mut [u8] {
264        self.buffer.as_mut_slice()
265    }
266
267    fn freeze(self: Box<Self>) -> ByteBuffer {
268        let Self { buffer, alignment } = *self;
269        let bytes = Bytes::from_owner(HostBufferOwner { buffer });
270        ByteBuffer::from_bytes_aligned(bytes, alignment)
271    }
272}
273
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Test allocator that counts `allocate` calls and delegates to [`DefaultHostAllocator`].
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(16, Alignment::new(8)).unwrap();
        for (idx, byte) in writable.as_mut_slice().iter_mut().enumerate() {
            *byte = u8::try_from(idx).unwrap();
        }

        let host = writable.freeze();
        assert_eq!(host.len(), 16);
        assert!(host.is_aligned(Alignment::new(8)));
        assert_eq!(host.as_slice(), (0u8..16).collect::<Vec<_>>().as_slice());
    }

    #[test]
    fn memory_session_replaces_allocator() {
        let allocations = Arc::new(AtomicUsize::new(0));
        let allocator = Arc::new(CountingAllocator {
            allocations: Arc::clone(&allocations),
        });
        let mut session = MemorySession::default();
        session.set_allocator(allocator);
        drop(session.allocator().allocate(4, Alignment::none()).unwrap());
        assert_eq!(allocations.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn typed_allocation_uses_type_alignment() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(writable.len(), 4 * size_of::<u64>());
        assert_eq!(writable.alignment(), Alignment::of::<u64>());
    }

    #[test]
    fn typed_mut_slice_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[10, 20, 30, 40]);

        let frozen = writable.freeze();
        // Decode the frozen bytes safely in native endianness (matching how the typed slice
        // wrote them) instead of reinterpreting the pointer in an `unsafe` block.
        let values: Vec<u64> = frozen
            .as_slice()
            .chunks_exact(size_of::<u64>())
            .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap()))
            .collect();
        assert_eq!(values, [10, 20, 30, 40]);
    }

    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate(7, Alignment::none()).unwrap();
        assert!(writable.as_mut_slice_typed::<u32>().is_err());
    }

    #[test]
    fn freeze_typed_round_trip() {
        let allocator = DefaultHostAllocator;
        let mut writable = allocator.allocate_typed::<u64>(4).unwrap();
        writable
            .as_mut_slice_typed::<u64>()
            .unwrap()
            .copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let allocator = DefaultHostAllocator;
        let writable = allocator.allocate(7, Alignment::none()).unwrap();
        let err = writable.freeze_typed::<u32>().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("not a multiple of"));
    }
}