sized_dst/
lib.rs

1#![doc = include_str!("../README.md")]
2//!
3//!
4//! If you want to change the alignment requirements of your DST (for example, your type may need
5//! to be aligned to a 32-byte boundary), see [`DstBase`], which is also where most of the
6//! documentation lives.
7
8#![cfg_attr(not(any(test, feature = "std")), no_std)]
9#![feature(ptr_metadata, unsize, pin_deref_mut)]
10#![warn(missing_docs)]
11
12mod assert;
13mod trait_impls;
14
15use core::{
16    any::Any,
17    marker::{PhantomData, Unsize},
18    mem::{size_of, MaybeUninit},
19    ops::{Deref, DerefMut},
20    ptr::{copy_nonoverlapping, drop_in_place, from_raw_parts, from_raw_parts_mut, Pointee},
21};
22
23use aligned::Aligned;
24
25pub use aligned::{Alignment, A1, A16, A2, A32, A4, A64, A8};
26
/// Given multiple type names, return the size of the biggest type.
///
/// This can be used in const context. It is intended for computing the required size for [`Dst`].
/// At least one type must be given; a trailing comma is accepted.
///
/// ```
/// # use core::fmt::Display;
/// # use core::mem::{size_of, size_of_val};
/// # use sized_dst::{max_size, Dst};
/// // Dst data will be the size of f64
/// let dst = Dst::<dyn Display, { max_size!(u32, f64, bool) }>::new(12.0);
/// assert!(size_of_val(&dst) > size_of::<f64>());
/// ```
#[macro_export]
macro_rules! max_size {
    ($first:ty $(, $other:ty)* $(,)?) => {{
        // The absolute `::core` path keeps the expansion working even if the calling crate has
        // its own item (module, alias, ...) named `core` in scope.
        // `mut` is unused when only a single type is passed.
        #[allow(unused_mut)]
        let mut max = ::core::mem::size_of::<$first>();
        $(
            let next = ::core::mem::size_of::<$other>();
            if next > max {
                max = next;
            }
        )*
        max
    }};
}
53
54/// Sized object that stores a DST object, such as a trait object, on the stack.
55///
56/// The layout of `DstBase` consists of the DST metadata (for trait objects, this is the vtable
57/// pointer) and a fixed block of memory storing the actual object.
58///
59/// `DstBase` implements `Deref` and `DerefMut` for the DST, so it can be used in place of the DST
60/// in most use-cases.
61///
62/// ```
63/// # use sized_dst::Dst;
64/// let dst = Dst::<dyn ToString, 8>::new(12u32);
65/// assert_eq!(dst.to_string(), "12");
66/// ```
67///
68/// Rather than using `DstBase` directly, use [`Dst`], which is aligned to the target word boundary.
69/// This is almost always what you want.
70///
71/// # Alignment and Capacity
72///
73/// The alignment and capacity of the backing storage are specified via the generic parameters `A`
74/// and `N` respectively. Since DST objects can have any size and alignment, the object being
75/// stored must fit and be well-aligned within the backing storage. These checks are done at
76/// **compile-time**, resulting in a compile-time error if the size or alignment requirements are
77/// not met.
78///
79/// For example, the following will work:
80/// ```
81/// # use sized_dst::{DstBase, A4};
82/// // u32 fits within 8 bytes and an alignment of 4
83/// let dst = DstBase::<dyn std::fmt::Debug, A4, 8>::new(12u32);
84/// ```
85/// however, this will fail to compile due to insufficient capacity:
86/// ```compile_fail
87/// # use sized_dst::{DstBase, A4};
88/// // [u32; 3] does not fit within 8 bytes
89/// let dst = DstBase::<dyn std::fmt::Debug, A4, 8>::new([1u32, 23u32, 0u32]);
90/// ```
91/// and this will also fail to compile due to insufficient alignment:
92/// ```compile_fail
93/// # use sized_dst::{DstBase, A4};
94/// // f64 does not fit the alignment requirement of 4 bytes
95/// let dst = DstBase::<dyn std::fmt::Debug, A4, 8>::new(12.0f64);
96/// ```
97pub struct DstBase<D: ?Sized + Pointee, A: Alignment, const N: usize> {
98    metadata: <D as Pointee>::Metadata,
99    obj_bytes: Aligned<A, [MaybeUninit<u8>; N]>,
100    // Technically we own an instance of D, so we need this for autotraits to be propagated
101    // correctly
102    _phantom: PhantomData<D>,
103}
104
105impl<D: ?Sized, A: Alignment, const N: usize> DstBase<D, A, N> {
106    /// Create a new `DstBase` from a sized `value`.
107    ///
108    /// `value` is coerced from its original type into the type D and stored into the `DstBase`.
109    /// The size and alignment of `value` are checked against the `DstBase` parameters at
110    /// **compile-time**, resulting in a compile error if `value` doesn't fit.
111    pub fn new<T: Unsize<D>>(value: T) -> Self {
112        assert::check_size_and_align_of_dst::<T, A, N>();
113
114        // SAFETY:
115        // - `val_size` is the size of `value`, as expected by the function.
116        // - Our assertions made sure that `value` fits in `self.obj_bytes`, and its alignment
117        //   requirement does not extracted that of `self.obj_bytes`.
118        // - We call `mem::forget` immediately after to prevent double-free.
119        let out = unsafe { Self::from_dyn(&value, size_of::<T>()) };
120        core::mem::forget(value);
121        out
122    }
123
124    /// SAFETY:
125    /// - `val_size` must be the size of the object `value` points to.
126    /// - `self.obj_bytes` must be at least `val_size` bytes long.
127    /// - `value`'s alignment requirements must not be more strict than `self.obj_bytes`.
128    /// - `mem::forget` must be called on the `value` object after this call.
129    unsafe fn from_dyn(value: &D, val_size: usize) -> Self {
130        // The metadata comes from a fat D pointer pointing to `value`. We can use the metadata
131        // to reconstruct the fat D pointer in the future.
132        let metadata = core::ptr::metadata(value as *const D);
133
134        let mut obj_bytes = Aligned([MaybeUninit::uninit(); N]);
135        // Move `value` into `obj_bytes`
136        //
137        // SAFETY:
138        // - `value` and `self.obj_bytes` are at least `val_size` bytes, so the copy is valid.
139        // - `value`'s alignment is not more strict than `self.obj_bytes`, so the copied `value`
140        //   will always be well-aligned.
141        // - `value` and `obj_bytes` are separate variables, so they can't overlap.
142        unsafe {
143            copy_nonoverlapping(
144                value as *const D as *const MaybeUninit<u8>,
145                obj_bytes.as_mut_ptr(),
146                val_size,
147            )
148        };
149
150        DstBase {
151            metadata,
152            obj_bytes,
153            _phantom: PhantomData,
154        }
155    }
156
157    /// Get a dereferenceable, well-aligned pointer to the stored DST object
158    fn as_ptr(&self) -> *const D {
159        // A value that coerces into `D` was written into `obj_bytes` in the constructor, so the
160        // pointer to `obj_bytes` always points to a valid, well-aligned instance of `D`.
161        // Additionally, `metadata` was extracted from a fat D pointer in the constructor . As a
162        // result, the reconstructed D pointer is guaranteed to be dereferenceable.
163        from_raw_parts(self.obj_bytes.as_ptr(), self.metadata)
164    }
165
166    /// Get a dereferenceable, well-aligned mutable pointer to the stored DST object
167    fn as_mut_ptr(&mut self) -> *mut D {
168        // See `as_ptr` for how the API guarantees are upholded
169        from_raw_parts_mut(self.obj_bytes.as_mut_ptr(), self.metadata)
170    }
171}
172
173macro_rules! downcast_impl {
174    ($dst:ty) => {
175        impl<A: Alignment, const N: usize> DstBase<$dst, A, N> {
176            /// Attempt to downcast to a concrete type
177            pub fn downcast<T: Any>(self) -> Option<T> {
178                if let Some(val_ref) = self.deref().downcast_ref() {
179                    // SAFETY:
180                    // - val_ref is a valid reference to T, so we're reading a valid value of T for sure.
181                    // - Call mem::forget on self so we don't drop T twice.
182                    let val = unsafe { core::ptr::read(val_ref as *const T) };
183                    core::mem::forget(self);
184                    Some(val)
185                } else {
186                    None
187                }
188            }
189        }
190    };
191}
192downcast_impl!(dyn Any);
193downcast_impl!(dyn Any + Send);
194downcast_impl!(dyn Any + Send + Sync);
195
196impl<D: ?Sized, A: Alignment, const N: usize> Drop for DstBase<D, A, N> {
197    fn drop(&mut self) {
198        // SAFETY:
199        // - `as_mut_ptr` is guaranteed to return a dereferenceable, well-aligned pointer.
200        // - The stored value has not been dropped previously, since `forget` was called in the
201        //   constructor.
202        unsafe { drop_in_place(self.as_mut_ptr()) }
203    }
204}
205
206impl<D: ?Sized, A: Alignment, const N: usize> Deref for DstBase<D, A, N> {
207    type Target = D;
208
209    fn deref(&self) -> &Self::Target {
210        // SAFETY:
211        // - `as_ptr` is guaranteed to return a dereferenceable, well-aligned pointer.
212        // - Lifetime of the return reference is constrained by the lifetime of `self`, so the
213        //   reference will never dangle.
214        unsafe { &*self.as_ptr() }
215    }
216}
217
218impl<D: ?Sized, A: Alignment, const N: usize> DerefMut for DstBase<D, A, N> {
219    fn deref_mut(&mut self) -> &mut Self::Target {
220        // SAFETY:
221        // - `as_mut_ptr` is guaranteed to return a dereferenceable, well-aligned pointer.
222        // - Lifetime of the return reference is constrained by the lifetime of `self`, so the
223        //   reference will never dangle.
224        unsafe { &mut *self.as_mut_ptr() }
225    }
226}
227
228/// [`DstBase`] storing an object with alignment of 1 byte
229pub type DstA1<D, const N: usize> = DstBase<D, A1, N>;
230/// [`DstBase`] storing an object with alignment of 2 bytes
231pub type DstA2<D, const N: usize> = DstBase<D, A2, N>;
232/// [`DstBase`] storing an object with alignment of 4 bytes
233pub type DstA4<D, const N: usize> = DstBase<D, A4, N>;
234/// [`DstBase`] storing an object with alignment of 8 bytes
235pub type DstA8<D, const N: usize> = DstBase<D, A8, N>;
236/// [`DstBase`] storing an object with alignment of 16 bytes
237pub type DstA16<D, const N: usize> = DstBase<D, A16, N>;
238/// [`DstBase`] storing an object with alignment of 32 bytes
239pub type DstA32<D, const N: usize> = DstBase<D, A32, N>;
240/// [`DstBase`] storing an object with alignment of 64 bytes
241pub type DstA64<D, const N: usize> = DstBase<D, A64, N>;
242
243#[cfg(target_pointer_width = "16")]
244/// [`DstBase`] aligned to the target word boundary. This is almost always what you want to use.
245pub type Dst<D, const N: usize> = DstA2<D, N>;
246#[cfg(target_pointer_width = "32")]
247/// [`DstBase`] aligned to the target word boundary. This is almost always what you want to use.
248pub type Dst<D, const N: usize> = DstA4<D, N>;
249#[cfg(target_pointer_width = "64")]
250/// [`DstBase`] with the alignment of `usize`. This is almost always what you want to use.
251pub type Dst<D, const N: usize> = DstA8<D, N>;
252
253/// [`Dst`] with the size and alignment of `usize`. This can be used to represent both `dyn`
254/// pointers and _`dyn` objects smaller than a pointer_.
255///
256/// ```
257/// # use std::fmt::Debug;
258/// # use sized_dst::DstPtr;
259/// fn print(ptr: DstPtr<dyn Debug + '_>) {
260///     println!("{ptr:?}");
261/// }
262///
263/// print(DstPtr::new(1u8));
264/// print(DstPtr::new(&1u8));
265/// print(DstPtr::new(&mut 1u8));
266/// print(DstPtr::new(Box::new(1u8)));
267/// ```
268pub type DstPtr<D> = Dst<D, { size_of::<usize>() }>;
269
#[cfg(test)]
mod tests {
    use super::*;

    // References are passed to the generic `new` on purpose: these tests store a *reference*
    // inside the `Dst`, which is exactly the borrow this clippy lint would suggest removing.
    #[allow(clippy::needless_borrows_for_generic_args)]
    #[test]
    fn to_string() {
        let n = 123;
        let mut obj = DstA8::<dyn std::fmt::Display, 8>::new(4);
        assert_eq!(obj.to_string(), "4");

        obj = DstA8::new('a');
        assert_eq!(obj.to_string(), "a");

        obj = DstA8::new(123u64);
        assert_eq!(obj.to_string(), "123");

        // This is safe, because n outlives obj
        obj = DstA8::new(&n);
        assert_eq!(obj.to_string(), "123");

        assert_eq!(align_of_val(&obj.obj_bytes), 8);
        assert!(size_of_val(&obj.obj_bytes) >= 8);
    }

    #[test]
    fn small() {
        let mut obj = DstA1::<dyn std::fmt::Debug, 2>::new(3u8);
        assert_eq!(align_of_val(&obj.obj_bytes), 1);
        assert!(size_of_val(&obj.obj_bytes) >= 2);

        obj = DstA1::new([1u8, 2u8]);
        assert_eq!(align_of_val(&obj.obj_bytes), 1);
        assert!(size_of_val(&obj.obj_bytes) >= 2);
    }

    // Storing a reference (`&0`) is intentional; see the note on `to_string`.
    #[allow(clippy::needless_borrows_for_generic_args)]
    #[test]
    fn native() {
        let mut obj = Dst::<dyn std::fmt::Debug, 16>::new(0usize);
        assert_eq!(align_of_val(&obj.obj_bytes), size_of::<usize>());
        assert!(size_of_val(&obj.obj_bytes) >= 16);

        obj = Dst::new(std::ptr::null::<*const String>());
        assert_eq!(align_of_val(&obj.obj_bytes), size_of::<usize>());
        assert!(size_of_val(&obj.obj_bytes) >= 16);

        obj = Dst::new(Box::new(32));
        assert_eq!(align_of_val(&obj.obj_bytes), size_of::<usize>());
        assert!(size_of_val(&obj.obj_bytes) >= 16);

        obj = Dst::new(&0);
        assert_eq!(align_of_val(&obj.obj_bytes), size_of::<usize>());
        assert!(size_of_val(&obj.obj_bytes) >= 16);
    }

    #[test]
    fn custom_trait_obj() {
        struct Test<'a> {
            bop_count: &'a mut u32,
            drop_count: &'a mut u32,
        }
        trait Bop {
            fn bop(&mut self);
        }
        impl<'a> Bop for Test<'a> {
            fn bop(&mut self) {
                *self.bop_count += 1;
            }
        }
        impl<'a> Drop for Test<'a> {
            fn drop(&mut self) {
                *self.drop_count += 1;
            }
        }

        let mut bop_count = 0;
        let mut drop_count = 0;
        let test = Test {
            bop_count: &mut bop_count,
            drop_count: &mut drop_count,
        };
        let mut obj = Dst::<dyn Bop, 20>::new(test);
        obj.bop();
        obj.bop();
        drop(obj);

        // We bopped twice
        assert_eq!(bop_count, 2);
        // Should have only dropped once
        assert_eq!(drop_count, 1);
    }

    #[test]
    fn slice() {
        let mut obj = DstA1::<[u8], 4>::new([b'a', b'b']);
        assert_eq!(obj.deref(), b"ab");

        obj = DstA1::<[u8], 4>::new([b'a', b'b', b'c', b'd']);
        assert_eq!(obj.deref(), b"abcd");

        obj = DstA1::<[u8], 4>::new([]);
        assert_eq!(obj.deref(), b"");
    }

    #[test]
    fn align32() {
        let obj = DstA32::<dyn std::fmt::Debug, 32>::new(aligned::Aligned::<A32, _>(0));
        assert_eq!(align_of_val(&obj.obj_bytes), 32);
    }

    #[test]
    fn any_replace() {
        let mut obj = Dst::<dyn Any, 32>::new(String::from("xyz"));
        let ref_mut = obj.downcast_mut::<String>().unwrap();
        assert_eq!(ref_mut, "xyz");

        // Use a downcasted reference to replace the inner object without changing the metadata.
        // The metadata should still be valid because the concrete type of the replacement object
        // is the exact same as the original object, so future from_raw_parts calls are still sound.
        *ref_mut = String::from("abc");
        assert_eq!(obj.downcast_ref::<String>().unwrap(), "abc");
    }

    #[test]
    fn downcast() {
        let obj = Dst::<dyn Any, 32>::new(Box::new(2u32));
        let val: Box<u32> = obj.downcast::<Box<u32>>().unwrap();
        assert_eq!(*val, 2);

        let obj = Dst::<dyn Any, 32>::new(Box::new(2u32));
        assert!(obj.downcast::<String>().is_none());
    }

    #[test]
    fn downcast_send_sync() {
        // The `downcast` impls generated for the `Send` / `Send + Sync` trait objects.
        let obj = Dst::<dyn Any + Send, 32>::new(7u16);
        assert_eq!(obj.downcast::<u16>(), Some(7));

        let obj = Dst::<dyn Any + Send + Sync, 32>::new(String::from("s"));
        assert_eq!(obj.downcast::<String>().as_deref(), Some("s"));
    }

    #[test]
    fn downcast_drops_exactly_once() {
        use std::rc::Rc;

        let counter = Rc::new(());

        // A successful downcast hands the value back without dropping it.
        let obj = Dst::<dyn Any, 32>::new(Rc::clone(&counter));
        assert_eq!(Rc::strong_count(&counter), 2);
        let back = obj.downcast::<Rc<()>>().unwrap();
        assert_eq!(Rc::strong_count(&counter), 2);
        drop(back);
        assert_eq!(Rc::strong_count(&counter), 1);

        // A failed downcast drops the stored value exactly once.
        let obj = Dst::<dyn Any, 32>::new(Rc::clone(&counter));
        assert!(obj.downcast::<String>().is_none());
        assert_eq!(Rc::strong_count(&counter), 1);
    }

    #[test]
    fn max_size() {
        assert_eq!(size_of::<[u8; max_size!(String)]>(), size_of::<String>());
        assert_eq!(size_of::<[u8; max_size!(u8, u32, u64)]>(), size_of::<u64>());
        // Trailing commas are accepted, and the order of the types does not matter.
        assert_eq!(max_size!(u64, u8, u32,), size_of::<u64>());
    }

    #[test]
    fn custom_dst() {
        struct CustomDyn<T: ?Sized> {
            label: u32,
            data: T,
        }

        let val = CustomDyn {
            label: 1,
            data: Box::new(12u64),
        };
        let mut obj = Dst::<CustomDyn<dyn ToString>, 16>::new(val);
        assert_eq!(obj.label, 1);
        obj.label = 10;
        assert_eq!(obj.label, 10);
        assert_eq!(obj.data.to_string(), "12");
    }

    // Storing references is intentional; see the note on `to_string`.
    #[allow(clippy::needless_borrows_for_generic_args)]
    #[test]
    fn dst_ptr() {
        fn assert(ptr: DstPtr<dyn ToString + '_>) {
            assert_eq!(ptr.to_string(), "1");
        }

        assert(DstPtr::new(1u8));
        assert(DstPtr::new(&1u32));
        assert(DstPtr::new(&mut 1u32));
        assert(DstPtr::new(Box::new(1u32)));
    }
}
440}