shapely_core/partial.rs

use crate::{trace, FieldError, ShapeDesc, Shapely, Slot};
use std::{alloc, ptr::NonNull};

/// Origin of the partial — did we allocate it? Or is it borrowed?
pub enum Origin<'s> {
    /// It was allocated via `alloc::alloc` and needs to be deallocated on drop,
    /// moving out, etc.
    HeapAllocated,

    /// It was generously lent to us by some outside code, and we are NOT
    /// to free it (although we should still uninitialize any fields that we initialized).
    Borrowed {
        parent: Option<&'s Partial<'s>>,
        init_mark: InitMark<'s>,
    },
}

/// A partially-initialized shape.
///
/// This type keeps track of the initialized state of every field, and only allows
/// getting the concrete type out (by value, boxed, or moved into a caller-provided
/// pointer) once all the fields have been initialized.
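///
/// A rough usage sketch (hedged: `FooBar` is a stand-in type, and `Slot::fill` is
/// assumed as the way to write a value into a slot; the actual `Slot` API may differ):
///
/// ```ignore
/// let mut partial = Partial::alloc(FooBar::shape_desc());
/// partial.slot_by_name("foo").unwrap().fill(42u64);
/// partial.slot_by_name("bar").unwrap().fill(String::from("hello"));
/// let foo_bar: FooBar = partial.build();
/// ```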
pub struct Partial<'s> {
    /// Address of the value we're building in memory.
    /// If the type is a ZST, then the addr will be dangling.
    pub(crate) addr: NonNull<u8>,

    /// Where `addr` came from (i.e. are we responsible for freeing it?)
    pub(crate) origin: Origin<'s>,

    /// Keeps track of which fields are initialized
    pub(crate) init_set: InitSet64,

    /// The shape we're building, asserted when building, but
    /// also when getting field slots, etc.
    pub(crate) shape: ShapeDesc,
}

/// We can build a tree of partials when deserializing, so `Partial<'s>` has to be covariant over `'s`.
fn _assert_partial_covariant<'long: 'short, 'short>(partial: Partial<'long>) -> Partial<'short> {
    partial
}

impl Drop for Partial<'_> {
    // This drop function is only really called when a partial is dropped without being fully
    // built out. Otherwise, it's forgotten because the value has been moved elsewhere.
    //
    // As a result, its only job is to drop any fields that may have been initialized, and finally
    // to free the memory for the partial itself if we own it.
    fn drop(&mut self) {
        match self.shape.get().innards {
            crate::Innards::Struct { fields } => {
                fields
                    .iter()
                    .enumerate()
                    .filter_map(|(i, field)| {
                        if self.init_set.is_set(i) {
                            Some((field, field.shape.get().drop_in_place?))
                        } else {
                            None
                        }
                    })
                    .for_each(|(field, drop_fn)| {
                        unsafe {
                            // SAFETY: field_addr is valid, aligned, and initialized.
                            //
                            // If the struct is a ZST, then `self.addr` is dangling.
                            // That also means that all the fields are ZSTs, which means
                            // the actual address we pass to the drop fn does not matter,
                            // but we do want the side effects.
                            //
                            // If the struct is not a ZST, then `self.addr` is a valid address.
                            // The fields can still be ZSTs, and that's not a special case, really.
                            drop_fn(self.addr.byte_add(field.offset).as_ptr());
                        }
                    })
            }
            crate::Innards::Scalar(_) => {
                if self.init_set.is_set(0) {
                    // Drop the scalar value if it has a drop function
                    if let Some(drop_fn) = self.shape.get().drop_in_place {
                        // SAFETY: self.addr is always valid for Scalar types,
                        // even for ZSTs where it might be dangling.
                        unsafe {
                            drop_fn(self.addr.as_ptr());
                        }
                    }
                }
            }
            _ => {}
        }

        self.deallocate()
    }
}

impl Partial<'_> {
    /// Allocates a partial on the heap for the given shape descriptor.
    pub fn alloc(shape: ShapeDesc) -> Self {
        let sh = shape.get();
        let layout = sh.layout;
        let addr = if layout.size() == 0 {
            // ZSTs need a well-aligned address
            sh.dangling()
        } else {
            let addr = unsafe { alloc::alloc(layout) };
            if addr.is_null() {
                alloc::handle_alloc_error(layout);
            }
            // SAFETY: We just allocated this memory and checked that it's not null,
            // so it's safe to create a NonNull from it.
            unsafe { NonNull::new_unchecked(addr) }
        };

        Self {
            origin: Origin::HeapAllocated,
            addr,
            init_set: Default::default(),
            shape,
        }
    }

    /// Borrows a `MaybeUninit<T>` and returns a `Partial`.
    ///
    /// Before calling `assume_init`, make sure to call `Partial::build_in_place()`.
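    ///
    /// A minimal sketch of the intended flow (hedged: `FooBar` is a stand-in type,
    /// and field initialization through slots is elided):
    ///
    /// ```ignore
    /// let mut uninit = std::mem::MaybeUninit::<FooBar>::uninit();
    /// let mut partial = Partial::borrow(&mut uninit);
    /// // ... initialize every field through slots ...
    /// partial.build_in_place(); // panics if any field is still uninitialized
    /// let foo_bar = unsafe { uninit.assume_init() };
    /// ```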
    pub fn borrow<T: Shapely>(uninit: &mut std::mem::MaybeUninit<T>) -> Self {
        Self {
            origin: Origin::Borrowed {
                parent: None,
                init_mark: InitMark::Ignored,
            },
            addr: NonNull::new(uninit.as_mut_ptr() as _).unwrap(),
            init_set: Default::default(),
            shape: T::shape_desc(),
        }
    }

    /// Checks if all fields in the struct or scalar value have been initialized.
    /// Panics if any field is not initialized, providing details about the uninitialized field.
    pub(crate) fn assert_all_fields_initialized(&self) {
        let shape = self.shape.get();

        trace!(
            "Checking initialization of \x1b[1;33m{}\x1b[0m partial at addr \x1b[1;36m{:p}\x1b[0m",
            shape,
            self.addr
        );
        match self.shape.get().innards {
            crate::Innards::Struct { fields } => {
                for (i, field) in fields.iter().enumerate() {
                    if self.init_set.is_set(i) {
                        trace!("Field \x1b[1;33m{}\x1b[0m is initialized", field.name);
                    } else {
                        panic!(
                            "Field '{}' was not initialized. Complete schema:\n{:?}",
                            field.name,
                            self.shape.get()
                        );
                    }
                }
            }
            crate::Innards::Scalar(_) => {
                if !self.init_set.is_set(0) {
                    panic!(
                        "Scalar value was not initialized. Complete schema:\n{:?}",
                        self.shape.get()
                    );
                }
            }
            _ => {}
        }
    }

    /// Returns a slot for assigning this whole shape as a scalar.
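    ///
    /// A sketch (hedged: assumes `u64` implements `Shapely` with scalar innards,
    /// and that `Slot::fill` is how a value gets written):
    ///
    /// ```ignore
    /// let mut partial = Partial::alloc(u64::shape_desc());
    /// partial.scalar_slot().unwrap().fill(42u64);
    /// assert_eq!(partial.build::<u64>(), 42);
    /// ```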
    pub fn scalar_slot(&mut self) -> Option<Slot<'_>> {
        match self.shape.get().innards {
            crate::Innards::Scalar(_) => {
                let slot = Slot::for_ptr(
                    self.addr,
                    self.shape,
                    InitMark::Struct {
                        index: 0,
                        set: &mut self.init_set,
                    },
                );
                Some(slot)
            }
            crate::Innards::Transparent(inner_shape) => {
                let slot = Slot::for_ptr(
                    self.addr,
                    inner_shape,
                    InitMark::Struct {
                        index: 0,
                        set: &mut self.init_set,
                    },
                );
                Some(slot)
            }
            _ => panic!(
                "Expected scalar innards, found {:?}",
                self.shape.get().innards
            ),
        }
    }

    /// Returns a slot for initializing a field in the shape by name.
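    ///
    /// For a struct shape, the field is looked up by name; for a hash map shape, a slot
    /// for the given key is returned. A sketch (hedged: `FooBar` is a stand-in type):
    ///
    /// ```ignore
    /// let mut partial = Partial::alloc(FooBar::shape_desc());
    /// let slot = partial.slot_by_name("foo")?; // Err(FieldError::NoSuchStaticField) if absent
    /// // fill the slot, then repeat for the remaining fields...
    /// ```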
    pub fn slot_by_name<'s>(&'s mut self, name: &str) -> Result<Slot<'s>, FieldError> {
        let slot = match self.shape.get().innards {
            crate::Innards::Struct { fields } => {
                let (index, field) = fields
                    .iter()
                    .enumerate()
                    .find(|(_, f)| f.name == name)
                    .ok_or(FieldError::NoSuchStaticField)?;
                let field_addr = unsafe {
                    // SAFETY: self.addr is a valid pointer to the start of the struct,
                    // and field.offset is the correct offset for this field within the struct.
                    // The resulting pointer is properly aligned and within the bounds of the allocated memory.
                    self.addr.byte_add(field.offset)
                };
                Slot::for_ptr(field_addr, field.shape, self.init_set.field(index))
            }
            crate::Innards::HashMap { value_shape } => {
                Slot::for_hash_map(self.addr, name.to_string(), value_shape)
            }
            crate::Innards::Array(_shape) => return Err(FieldError::NoStaticFields),
            crate::Innards::Transparent(_shape) => return Err(FieldError::NoStaticFields),
            crate::Innards::Scalar(_scalar) => return Err(FieldError::NoStaticFields),
        };
        Ok(slot)
    }

    /// Returns a slot for initializing a field in the shape by index.
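    ///
    /// Useful when filling fields positionally, e.g. from a sequence during deserialization.
    /// A sketch (hedged: `field_count` stands in for however many fields the shape has):
    ///
    /// ```ignore
    /// for index in 0..field_count {
    ///     let slot = partial.slot_by_index(index)?;
    ///     // fill the slot from the input...
    /// }
    /// ```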
    pub fn slot_by_index(&mut self, index: usize) -> Result<Slot<'_>, FieldError> {
        let sh = self.shape.get();
        let field = sh.field_by_index(index)?;
        let field_addr = unsafe {
            // SAFETY: self.addr is a valid pointer to the start of the struct,
            // and field.offset is the correct offset for this field within the struct.
            // The resulting pointer is properly aligned and within the bounds of the allocated memory.
            self.addr.byte_add(field.offset)
        };
        let slot = Slot::for_ptr(field_addr, field.shape, self.init_set.field(index));
        Ok(slot)
    }

    fn assert_matching_shape<T: Shapely>(&self) {
        if self.shape != T::shape_desc() {
            let partial_shape = self.shape.get();
            let target_shape = T::shape();

            panic!(
                "This is a partial \x1b[1;34m{}\x1b[0m, you can't build a \x1b[1;32m{}\x1b[0m out of it",
                partial_shape,
                target_shape,
            );
        }
    }

    fn deallocate(&mut self) {
        // ZSTs don't need to be deallocated
        if self.shape.get().layout.size() != 0 {
            unsafe { alloc::dealloc(self.addr.as_ptr(), self.shape.get().layout) }
        }
    }

    /// Asserts that every field has been initialized and forgets the Partial.
    ///
    /// This method is only used when the origin is borrowed.
    /// If this method is not called, all fields will be freed when the Partial is dropped.
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - The origin is not borrowed (i.e., it's heap allocated).
    /// - Any field is not initialized.
    pub fn build_in_place(mut self) {
        // ensure all fields are initialized
        self.assert_all_fields_initialized();

        match &mut self.origin {
            Origin::Borrowed { init_mark, .. } => {
                // Mark the borrowed field as initialized
                init_mark.set();
            }
            Origin::HeapAllocated => {
                panic!("Cannot build in place for heap allocated Partial");
            }
        }

        // prevent field drops when the Partial is dropped
        std::mem::forget(self);
    }

    /// Builds this partial into the completed shape.
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - Not all the fields have been initialized.
    /// - The generic type parameter T does not match the shape that this partial is building.
    pub fn build<T: Shapely>(mut self) -> T {
        self.assert_all_fields_initialized();
        self.assert_matching_shape::<T>();

        // SAFETY: We've verified that all fields are initialized and that the shape matches T.
        // For zero-sized types, all pointer values are valid.
        // See https://doc.rust-lang.org/stable/std/ptr/index.html#safety for more details.
        let result = unsafe {
            let ptr = self.addr.as_ptr() as *const T;
            std::ptr::read(ptr)
        };
        trace!(
            "Built \x1b[1;33m{}\x1b[0m successfully",
            std::any::type_name::<T>()
        );
        self.deallocate();
        std::mem::forget(self);
        result
    }

    /// Builds this partial into a boxed completed shape.
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - Not all the fields have been initialized.
    /// - The generic type parameter T does not match the shape that this partial is building.
    ///
    /// # Safety
    ///
    /// This function uses unsafe code to create a Box from a raw pointer.
    /// It's safe because we've verified the initialization and shape matching,
    /// and we forget `self` to prevent double-freeing.
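    ///
    /// A sketch (hedged: `FooBar` is a stand-in type, and field initialization is elided):
    ///
    /// ```ignore
    /// let mut partial = Partial::alloc(FooBar::shape_desc());
    /// // ... initialize every field through slots ...
    /// let boxed: Box<FooBar> = partial.build_boxed();
    /// ```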
    pub fn build_boxed<T: Shapely>(self) -> Box<T> {
        self.assert_all_fields_initialized();
        self.assert_matching_shape::<T>();

        let boxed = unsafe { Box::from_raw(self.addr.as_ptr() as *mut T) };
        std::mem::forget(self);
        boxed
    }

    /// Moves the contents of this `Partial` into a target memory location.
    ///
    /// This function is useful when you need to place the fully initialized value
    /// into a specific memory address, such as when working with FFI or custom allocators.
    ///
    /// # Safety
    ///
    /// The target pointer must be valid and properly aligned,
    /// and must be large enough to hold the value.
    /// The caller is responsible for ensuring that the target memory is properly deallocated
    /// when it's no longer needed.
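    ///
    /// A sketch (hedged: `FooBar` is a stand-in type; the target is allocated with the
    /// same layout the partial was built with):
    ///
    /// ```ignore
    /// let mut partial = Partial::alloc(FooBar::shape_desc());
    /// // ... initialize every field through slots ...
    /// let layout = std::alloc::Layout::new::<FooBar>();
    /// let target = NonNull::new(unsafe { std::alloc::alloc(layout) }).unwrap();
    /// unsafe { partial.move_into(target) };
    /// // `target` now holds a fully initialized FooBar; the caller owns that memory.
    /// ```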
    pub unsafe fn move_into(mut self, target: NonNull<u8>) {
        self.assert_all_fields_initialized();
        unsafe {
            std::ptr::copy_nonoverlapping(
                self.addr.as_ptr(),
                target.as_ptr(),
                // note: copy_nonoverlapping takes a count,
                // since we're dealing with `*mut u8`, it's a byte count.
                // if we were dealing with `*mut ()`, we'd have a nasty surprise.
                self.shape.get().layout.size(),
            );
        }
        self.deallocate();
        std::mem::forget(self);
    }

    /// Returns the shape we're currently building.
    pub fn shape(&self) -> ShapeDesc {
        self.shape
    }

    /// Returns the address of the value we're building in memory.
    pub fn addr(&self) -> NonNull<u8> {
        self.addr
    }
}

/// A bit array to keep track of which fields were initialized, up to 64 fields.
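///
/// A quick sketch of how the bits behave:
///
/// ```ignore
/// let mut set = InitSet64::default();
/// set.set(0);
/// set.set(2);
/// assert!(set.is_set(0) && !set.is_set(1));
/// assert!(!set.all_set(3));
/// set.set(1);
/// assert!(set.all_set(3));
/// ```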
#[derive(Clone, Copy, Default)]
pub struct InitSet64(u64);

impl InitSet64 {
    /// Sets the bit at the given index.
    pub fn set(&mut self, index: usize) {
        if index >= 64 {
            panic!("InitSet64 can only track up to 64 fields. Index {index} is out of bounds.");
        }
        self.0 |= 1 << index;
    }

    /// Unsets the bit at the given index.
    pub fn unset(&mut self, index: usize) {
        if index >= 64 {
            panic!("InitSet64 can only track up to 64 fields. Index {index} is out of bounds.");
        }
        self.0 &= !(1 << index);
    }

    /// Checks if the bit at the given index is set.
    pub fn is_set(&self, index: usize) -> bool {
        if index >= 64 {
            panic!("InitSet64 can only track up to 64 fields. Index {index} is out of bounds.");
        }
        (self.0 & (1 << index)) != 0
    }

    /// Checks if all bits up to the given count are set.
    pub fn all_set(&self, count: usize) -> bool {
        if count > 64 {
            panic!("InitSet64 can only track up to 64 fields. Count {count} is out of bounds.");
        }
        if count == 0 {
            return true;
        }
        // `1 << 64` would overflow, so build the mask by shifting down from all-ones instead.
        let mask = u64::MAX >> (64 - count);
        self.0 & mask == mask
    }

    /// Gets an [InitMark] to track the initialization state of a single field
    pub fn field(&mut self, index: usize) -> InitMark {
        InitMark::Struct { index, set: self }
    }
}

/// `InitMark` is used to track the initialization state of a single field within an `InitSet64`.
/// It is part of a system used to progressively initialize structs, where each field's
/// initialization status is represented by a bit in a 64-bit set.
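///
/// A quick sketch of how a mark ties back to its `InitSet64`:
///
/// ```ignore
/// let mut set = InitSet64::default();
/// let mut mark = set.field(3);
/// assert!(!mark.get());
/// mark.set();
/// assert!(mark.get());
/// assert!(set.is_set(3));
/// ```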
pub enum InitMark<'s> {
    /// Represents a field in a struct that needs to be tracked for initialization.
    Struct {
        /// The index of the field in the struct (0-63).
        index: usize,
        /// A reference to the `InitSet64` that tracks all fields' initialization states.
        set: &'s mut InitSet64,
    },
    /// Represents a field or value that doesn't need initialization tracking.
    Ignored,
}

impl InitMark<'_> {
    /// Marks the field as initialized by setting its corresponding bit in the `InitSet64`.
    pub fn set(&mut self) {
        if let Self::Struct { index, set } = self {
            set.set(*index);
        }
    }

    /// Marks the field as uninitialized by clearing its corresponding bit in the `InitSet64`.
    pub fn unset(&mut self) {
        if let Self::Struct { index, set } = self {
            set.0 &= !(1 << *index);
        }
    }

    /// Checks if the field is marked as initialized.
    ///
    /// Returns `true` if the field is initialized, `false` otherwise.
    /// Always returns `true` for `Ignored` fields.
    pub fn get(&self) -> bool {
        match self {
            Self::Struct { index, set } => set.is_set(*index),
            Self::Ignored => true,
        }
    }
}