Skip to main content

vortex_array/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped memory allocation for host-side buffers.
5
6use std::any::Any;
7use std::fmt::Debug;
8use std::mem::size_of;
9use std::sync::Arc;
10
11use bytes::Bytes;
12use vortex_buffer::Alignment;
13use vortex_buffer::Buffer;
14use vortex_buffer::ByteBuffer;
15use vortex_buffer::ByteBufferMut;
16use vortex_error::VortexResult;
17use vortex_error::vortex_ensure;
18use vortex_error::vortex_err;
19use vortex_session::SessionExt;
20use vortex_session::SessionGuard;
21use vortex_session::SessionVar;
22use vortex_session::VortexSession;
23
24/// Mutable host buffer contract used by [`WritableHostBuffer`].
25pub trait HostBufferMut: Send + 'static {
26    /// Returns the logical byte length of the buffer.
27    fn len(&self) -> usize;
28
29    /// Whether the buffer is empty.
30    fn is_empty(&self) -> bool {
31        self.len() == 0
32    }
33
34    /// Returns the alignment of the buffer.
35    fn alignment(&self) -> Alignment;
36
37    /// Returns mutable access to the writable byte range.
38    fn as_mut_slice(&mut self) -> &mut [u8];
39
40    /// Freeze the buffer into an immutable [`ByteBuffer`].
41    fn freeze(self: Box<Self>) -> ByteBuffer;
42}
43
44/// Exact-size writable host buffer returned by a [`HostAllocator`].
45pub struct WritableHostBuffer {
46    inner: Box<dyn HostBufferMut>,
47}
48
49impl WritableHostBuffer {
50    /// Create a writable host buffer from an implementation of [`HostBufferMut`].
51    pub fn new(inner: Box<dyn HostBufferMut>) -> Self {
52        Self { inner }
53    }
54
55    /// Returns the logical byte length of the buffer.
56    pub fn len(&self) -> usize {
57        self.inner.len()
58    }
59
60    /// Returns true when the buffer has zero bytes.
61    pub fn is_empty(&self) -> bool {
62        self.len() == 0
63    }
64
65    /// Returns the alignment of the buffer.
66    pub fn alignment(&self) -> Alignment {
67        self.inner.alignment()
68    }
69
70    /// Returns mutable access to the writable byte range.
71    pub fn as_mut_slice(&mut self) -> &mut [u8] {
72        self.inner.as_mut_slice()
73    }
74
75    /// Returns mutable access to the buffer as a typed slice.
76    pub fn as_mut_slice_typed<T>(&mut self) -> VortexResult<&mut [T]> {
77        vortex_ensure!(
78            size_of::<T>() != 0,
79            InvalidArgument: "Cannot create typed mutable slice for zero-sized type {}",
80            std::any::type_name::<T>()
81        );
82        vortex_ensure!(
83            self.alignment().is_aligned_to(Alignment::of::<T>()),
84            InvalidArgument: "Buffer is not sufficiently aligned for type {}",
85            std::any::type_name::<T>()
86        );
87
88        let bytes = self.as_mut_slice();
89        let byte_len = bytes.len();
90        let ptr = bytes.as_mut_ptr();
91        let type_size = size_of::<T>();
92
93        vortex_ensure!(
94            byte_len.is_multiple_of(type_size),
95            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
96            type_size,
97            std::any::type_name::<T>()
98        );
99
100        // SAFETY: We checked size divisibility and pointer alignment for `T`,
101        // and we have exclusive mutable access to the underlying bytes.
102        Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), byte_len / type_size) })
103    }
104
105    /// Freeze the writable buffer into an immutable [`ByteBuffer`].
106    pub fn freeze(self) -> ByteBuffer {
107        self.inner.freeze()
108    }
109
110    /// Freeze the writable buffer into a typed immutable [`Buffer<T>`].
111    pub fn freeze_typed<T>(self) -> VortexResult<Buffer<T>> {
112        vortex_ensure!(
113            size_of::<T>() != 0,
114            InvalidArgument: "Cannot freeze typed buffer for zero-sized type {}",
115            std::any::type_name::<T>()
116        );
117
118        let buffer = self.freeze();
119        let byte_len = buffer.len();
120        let type_size = size_of::<T>();
121        let type_align = Alignment::of::<T>();
122
123        vortex_ensure!(
124            byte_len.is_multiple_of(type_size),
125            InvalidArgument: "Buffer length {byte_len} is not a multiple of {} for {}",
126            type_size,
127            std::any::type_name::<T>()
128        );
129        vortex_ensure!(
130            buffer.is_aligned(type_align),
131            InvalidArgument: "Buffer pointer is not aligned to {} for {}",
132            type_align,
133            std::any::type_name::<T>()
134        );
135
136        Ok(Buffer::from_byte_buffer(buffer))
137    }
138}
139
140impl Debug for WritableHostBuffer {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.debug_struct("WritableHostBuffer")
143            .field("len", &self.len())
144            .field("alignment", &self.alignment())
145            .finish()
146    }
147}
148
149/// Allocator for exact-size writable host buffers.
150pub trait HostAllocator: Debug + Send + Sync + 'static {
151    /// Allocate a writable host buffer with the requested byte length and alignment.
152    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer>;
153}
154
155/// Shared allocator reference used throughout session-scoped memory APIs.
156pub type HostAllocatorRef = Arc<dyn HostAllocator>;
157
158/// Extension methods for [`HostAllocator`]s.
159pub trait HostAllocatorExt: HostAllocator {
160    /// Allocate host memory for `len` elements of `T` using `Alignment::of::<T>()`.
161    fn allocate_typed<T>(&self, len: usize) -> VortexResult<WritableHostBuffer> {
162        let bytes = len.checked_mul(size_of::<T>()).ok_or_else(|| {
163            vortex_err!(
164                "Typed host allocation overflow for type {} and len {}",
165                std::any::type_name::<T>(),
166                len
167            )
168        })?;
169        self.allocate(bytes, Alignment::of::<T>())
170    }
171}
172
173impl<A: HostAllocator + ?Sized> HostAllocatorExt for A {}
174
175/// Session-scoped memory configuration for Vortex arrays.
176#[derive(Clone, Debug)]
177pub struct MemorySession {
178    allocator: HostAllocatorRef,
179}
180
181impl MemorySession {
182    /// Creates a new session memory configuration using the provided allocator.
183    pub fn new(allocator: HostAllocatorRef) -> Self {
184        Self { allocator }
185    }
186
187    /// Returns the configured allocator.
188    pub fn allocator(&self) -> HostAllocatorRef {
189        Arc::clone(&self.allocator)
190    }
191
192    /// Updates the configured allocator.
193    pub fn set_allocator(&mut self, allocator: HostAllocatorRef) {
194        self.allocator = allocator;
195    }
196}
197
198impl Default for MemorySession {
199    fn default() -> Self {
200        Self::new(Arc::new(DefaultHostAllocator))
201    }
202}
203
204impl SessionVar for MemorySession {
205    fn as_any(&self) -> &dyn Any {
206        self
207    }
208
209    fn as_any_mut(&mut self) -> &mut dyn Any {
210        self
211    }
212}
213
214/// Extension trait for accessing session-scoped memory configuration.
215pub trait MemorySessionExt: SessionExt {
216    /// Returns the memory session for this execution/session context.
217    fn memory(&self) -> SessionGuard<'_, MemorySession> {
218        self.get::<MemorySession>()
219    }
220
221    /// Returns the configured host allocator for this execution/session context.
222    fn allocator(&self) -> HostAllocatorRef {
223        self.memory().allocator()
224    }
225
226    /// Configures the session to use `allocator` as its host allocator, mutating it in place and
227    /// returning it for chaining.
228    fn with_allocator(self, allocator: HostAllocatorRef) -> VortexSession {
229        let session = self.session();
230        session.get_mut::<MemorySession>().set_allocator(allocator);
231        session
232    }
233}
234
235impl<S: SessionExt> MemorySessionExt for S {}
236
/// Stateless host allocator used when a session has no custom allocator configured.
///
/// Buffers are backed by [`ByteBufferMut`].
#[derive(Default, Debug)]
pub struct DefaultHostAllocator;
240
241impl HostAllocator for DefaultHostAllocator {
242    fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
243        let mut buffer = ByteBufferMut::with_capacity_aligned(len, alignment);
244        // SAFETY: We fully initialize this slice before freezing it.
245        unsafe { buffer.set_len(len) };
246        Ok(WritableHostBuffer::new(Box::new(
247            DefaultWritableHostBuffer { buffer, alignment },
248        )))
249    }
250}
251
252#[derive(Debug)]
253struct DefaultWritableHostBuffer {
254    buffer: ByteBufferMut,
255    alignment: Alignment,
256}
257
258#[derive(Debug)]
259struct HostBufferOwner {
260    buffer: ByteBufferMut,
261}
262
263impl AsRef<[u8]> for HostBufferOwner {
264    fn as_ref(&self) -> &[u8] {
265        self.buffer.as_slice()
266    }
267}
268
269impl HostBufferMut for DefaultWritableHostBuffer {
270    fn len(&self) -> usize {
271        self.buffer.len()
272    }
273
274    fn alignment(&self) -> Alignment {
275        self.alignment
276    }
277
278    fn as_mut_slice(&mut self) -> &mut [u8] {
279        self.buffer.as_mut_slice()
280    }
281
282    fn freeze(self: Box<Self>) -> ByteBuffer {
283        let Self { buffer, alignment } = *self;
284        let bytes = Bytes::from_owner(HostBufferOwner { buffer });
285        ByteBuffer::from_bytes_aligned(bytes, alignment)
286    }
287}
288
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Allocator that counts every request and then delegates to [`DefaultHostAllocator`].
    #[derive(Debug)]
    struct CountingAllocator {
        allocations: Arc<AtomicUsize>,
    }

    impl HostAllocator for CountingAllocator {
        fn allocate(&self, len: usize, alignment: Alignment) -> VortexResult<WritableHostBuffer> {
            self.allocations.fetch_add(1, Ordering::Relaxed);
            DefaultHostAllocator.allocate(len, alignment)
        }
    }

    /// Bytes written through the writable view survive `freeze` with length and alignment intact.
    #[test]
    fn writable_host_buffer_freeze_round_trip() {
        let mut writable = DefaultHostAllocator
            .allocate(16, Alignment::new(8))
            .unwrap();
        writable
            .as_mut_slice()
            .iter_mut()
            .enumerate()
            .for_each(|(position, slot)| *slot = u8::try_from(position).unwrap());

        let frozen = writable.freeze();
        let expected: Vec<u8> = (0u8..16).collect();
        assert_eq!(frozen.len(), 16);
        assert!(frozen.is_aligned(Alignment::new(8)));
        assert_eq!(frozen.as_slice(), expected.as_slice());
    }

    /// After `set_allocator`, allocations are routed through the replacement allocator.
    #[test]
    fn memory_session_replaces_allocator() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counting = Arc::new(CountingAllocator {
            allocations: Arc::clone(&counter),
        });

        let mut memory = MemorySession::default();
        memory.set_allocator(counting);

        let buffer = memory.allocator().allocate(4, Alignment::none()).unwrap();
        drop(buffer);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    /// `allocate_typed` sizes the buffer in bytes and uses the element type's alignment.
    #[test]
    fn typed_allocation_uses_type_alignment() {
        let buffer = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        assert_eq!(buffer.len(), 4 * size_of::<u64>());
        assert_eq!(buffer.alignment(), Alignment::of::<u64>());
    }

    /// Values written through the typed view are readable from the frozen bytes.
    #[test]
    fn typed_mut_slice_round_trip() {
        let mut writable = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        let typed = writable.as_mut_slice_typed::<u64>().unwrap();
        typed.copy_from_slice(&[10, 20, 30, 40]);

        let frozen = writable.freeze();
        let element_count = frozen.len() / size_of::<u64>();
        // SAFETY: the buffer was allocated for `u64` elements and fully written as `u64` values
        // above, so the bytes are initialized and suitably aligned for `u64`.
        let values = unsafe {
            std::slice::from_raw_parts(frozen.as_slice().as_ptr().cast::<u64>(), element_count)
        };
        assert_eq!(values, [10, 20, 30, 40]);
    }

    /// A 7-byte buffer with no alignment cannot be viewed as `u32`.
    #[test]
    fn typed_mut_slice_rejects_length_mismatch() {
        let mut writable = DefaultHostAllocator
            .allocate(7, Alignment::none())
            .unwrap();
        let outcome = writable.as_mut_slice_typed::<u32>();
        assert!(outcome.is_err());
    }

    /// Typed freezing preserves the values written through the typed view.
    #[test]
    fn freeze_typed_round_trip() {
        let mut writable = DefaultHostAllocator.allocate_typed::<u64>(4).unwrap();
        let typed = writable.as_mut_slice_typed::<u64>().unwrap();
        typed.copy_from_slice(&[1, 3, 5, 7]);

        let frozen = writable.freeze_typed::<u64>().unwrap();
        assert_eq!(frozen.as_slice(), [1, 3, 5, 7]);
    }

    /// Typed freezing reports a length error for a byte count that `u32` does not divide.
    #[test]
    fn freeze_typed_rejects_length_mismatch() {
        let writable = DefaultHostAllocator
            .allocate(7, Alignment::none())
            .unwrap();
        let message = writable.freeze_typed::<u32>().unwrap_err().to_string();
        assert!(message.contains("not a multiple of"));
    }
}