vmm_sys_util/
fam.rs

1// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2//
3// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
4//
5// SPDX-License-Identifier: BSD-3-Clause
6
7//! Trait and wrapper for working with C defined FAM structures.
8//!
9//! In C 99 an array of unknown size may appear within a struct definition as the last member
10//! (as long as there is at least one other named member).
11//! This is known as a flexible array member (FAM).
12//! Pre C99, the same behavior could be achieved using zero length arrays.
13//!
14//! Flexible Array Members are the go-to choice for working with large amounts of data
15//! prefixed by header values.
16//!
17//! For example the KVM API has many structures of this kind.
18
19#[cfg(feature = "with-serde")]
20use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
21#[cfg(feature = "with-serde")]
22use serde::{ser::SerializeTuple, Serialize, Serializer};
23use std::fmt;
24use std::fmt::{Debug, Formatter};
25#[cfg(feature = "with-serde")]
26use std::marker::PhantomData;
27use std::mem::{self, size_of};
28
/// Failures reported by [`FamStructWrapper`](struct.FamStructWrapper.html) operations.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// The max size has been exceeded
    SizeLimitExceeded,
}

impl fmt::Display for Error {
    /// Writes the human-readable description of the error variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            Self::SizeLimitExceeded => "The max size has been exceeded",
        };
        f.write_str(description)
    }
}

// The default trait methods are enough: the error wraps no underlying source.
impl std::error::Error for Error {}
45
46/// Trait for accessing properties of C defined FAM structures.
47///
48/// # Safety
49///
50/// This is unsafe due to the number of constraints that aren't checked:
/// * the implementor should be a POD
/// * the implementor should contain a flexible array member of elements of type `Entry`
/// * `Entry` should be a POD
/// * the implementor should ensure that the FAM length as returned by [`FamStruct::len()`]
///   always correctly describes the length of the flexible array member.
///
/// Violating these constraints may lead to undefined behavior.
58///
59/// # Example
60///
61/// ```
62/// use vmm_sys_util::fam::*;
63///
64/// #[repr(C)]
65/// #[derive(Default)]
66/// pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
67/// impl<T> __IncompleteArrayField<T> {
68///     #[inline]
69///     pub fn new() -> Self {
70///         __IncompleteArrayField(::std::marker::PhantomData, [])
71///     }
72///     #[inline]
73///     pub unsafe fn as_ptr(&self) -> *const T {
74///         ::std::mem::transmute(self)
75///     }
76///     #[inline]
77///     pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
78///         ::std::mem::transmute(self)
79///     }
80///     #[inline]
81///     pub unsafe fn as_slice(&self, len: usize) -> &[T] {
82///         ::std::slice::from_raw_parts(self.as_ptr(), len)
83///     }
84///     #[inline]
85///     pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
86///         ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
87///     }
88/// }
89///
90/// #[repr(C)]
91/// #[derive(Default)]
92/// struct MockFamStruct {
93///     pub len: u32,
94///     pub padding: u32,
95///     pub entries: __IncompleteArrayField<u32>,
96/// }
97///
98/// unsafe impl FamStruct for MockFamStruct {
99///     type Entry = u32;
100///
101///     fn len(&self) -> usize {
102///         self.len as usize
103///     }
104///
105///     unsafe fn set_len(&mut self, len: usize) {
106///         self.len = len as u32
107///     }
108///
109///     fn max_len() -> usize {
110///         100
111///     }
112///
113///     fn as_slice(&self) -> &[u32] {
114///         let len = self.len();
115///         unsafe { self.entries.as_slice(len) }
116///     }
117///
118///     fn as_mut_slice(&mut self) -> &mut [u32] {
119///         let len = self.len();
120///         unsafe { self.entries.as_mut_slice(len) }
121///     }
122/// }
123///
124/// type MockFamStructWrapper = FamStructWrapper<MockFamStruct>;
125/// ```
// The trait exposes `len` without an `is_empty` counterpart, hence the silenced lint.
#[allow(clippy::len_without_is_empty)]
pub unsafe trait FamStruct {
    /// The type of the FAM entries
    ///
    /// Entries must be `Copy`; the wrapper moves them around with plain copies.
    type Entry: Copy;

    /// Get the FAM length
    ///
    /// These type of structures contain a member that holds the FAM length.
    /// This method will return the value of that member.
    fn len(&self) -> usize;

    /// Set the FAM length
    ///
    /// These type of structures contain a member that holds the FAM length.
    /// This method will set the value of that member.
    ///
    /// # Safety
    ///
    /// The caller needs to ensure that `len` here reflects the correct number of entries of the
    /// flexible array part of the struct.
    unsafe fn set_len(&mut self, len: usize);

    /// Get max allowed FAM length
    ///
    /// This depends on each structure.
    /// For example a structure representing the cpuid can contain at most 80 entries.
    fn max_len() -> usize;

    /// Get the FAM entries as slice
    ///
    /// The returned slice is expected to contain exactly [`FamStruct::len()`] entries.
    fn as_slice(&self) -> &[Self::Entry];

    /// Get the FAM entries as mut slice
    ///
    /// The returned slice is expected to contain exactly [`FamStruct::len()`] entries.
    fn as_mut_slice(&mut self) -> &mut [Self::Entry];
}
160
/// A wrapper for [`FamStruct`](trait.FamStruct.html).
///
/// It helps in treating a [`FamStruct`](trait.FamStruct.html) similarly to an actual `Vec`.
pub struct FamStructWrapper<T: Default + FamStruct> {
    // This variable holds the FamStruct structure. We use a `Vec<T>` to make the allocation
    // large enough while still being aligned for `T`. Only the first element of `Vec<T>`
    // will actually be used as a `T`. The remaining memory in the `Vec<T>` is for `entries`,
    // which must be contiguous. Since the entries are of type `FamStruct::Entry` we must
    // be careful to convert the desired capacity of the `FamStructWrapper`
    // from `FamStruct::Entry` to `T` when reserving or releasing memory.
    //
    // The accessors index `mem_allocator[0]` unconditionally, so the vector is expected to
    // always hold at least one element (the header).
    mem_allocator: Vec<T>,
}
173
174impl<T> Debug for FamStructWrapper<T>
175where
176    T: Default + FamStruct + Debug,
177    T::Entry: Debug,
178{
179    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
180        f.debug_struct("FamStructWrapper")
181            .field("fam_struct", &self.as_fam_struct_ref())
182            .field("entries", &self.as_fam_struct_ref().as_slice())
183            .finish()
184    }
185}
186
187impl<T: Default + FamStruct> FamStructWrapper<T> {
188    /// Convert FAM len to `mem_allocator` len.
189    ///
190    /// Get the capacity required by mem_allocator in order to hold
191    /// the provided number of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry).
192    /// Returns `None` if the required length would overflow usize.
193    fn mem_allocator_len(fam_len: usize) -> Option<usize> {
194        let wrapper_size_in_bytes =
195            size_of::<T>().checked_add(fam_len.checked_mul(size_of::<T::Entry>())?)?;
196
197        wrapper_size_in_bytes
198            .checked_add(size_of::<T>().checked_sub(1)?)?
199            .checked_div(size_of::<T>())
200    }
201
202    /// Convert `mem_allocator` len to FAM len.
203    ///
204    /// Get the number of elements of type
205    /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry)
206    /// that fit in a mem_allocator of provided len.
207    fn fam_len(mem_allocator_len: usize) -> usize {
208        if mem_allocator_len == 0 {
209            return 0;
210        }
211
212        let array_size_in_bytes = (mem_allocator_len - 1) * size_of::<T>();
213        array_size_in_bytes / size_of::<T::Entry>()
214    }
215
216    /// Create a new FamStructWrapper with `num_elements` elements.
217    ///
218    /// The elements will be zero-initialized. The type of the elements will be
219    /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry).
220    ///
221    /// # Arguments
222    ///
223    /// * `num_elements` - The number of elements in the FamStructWrapper.
224    ///
225    /// # Errors
226    ///
227    /// When `num_elements` is greater than the max possible len, it returns
228    /// `Error::SizeLimitExceeded`.
229    pub fn new(num_elements: usize) -> Result<FamStructWrapper<T>, Error> {
230        if num_elements > T::max_len() {
231            return Err(Error::SizeLimitExceeded);
232        }
233        let required_mem_allocator_capacity =
234            FamStructWrapper::<T>::mem_allocator_len(num_elements)
235                .ok_or(Error::SizeLimitExceeded)?;
236
237        let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);
238        mem_allocator.push(T::default());
239        for _ in 1..required_mem_allocator_capacity {
240            // SAFETY: Safe as long T follows the requirements of being POD.
241            mem_allocator.push(unsafe { mem::zeroed() })
242        }
243        // SAFETY: The flexible array part of the struct has `num_elements` capacity. We just
244        // initialized this in `mem_allocator`.
245        unsafe {
246            mem_allocator[0].set_len(num_elements);
247        }
248
249        Ok(FamStructWrapper { mem_allocator })
250    }
251
252    /// Constructs a FamStructWrapper with an empty flexible array member
253    /// from the given FamStruct header.
254    ///
255    /// # Errors
256    ///
257    /// If the length stored in the header is not 0, returns [`Error::SizeLimitExceeded`]
258    pub fn from_header(header: T) -> Result<FamStructWrapper<T>, Error> {
259        if header.len() != 0 {
260            return Err(Error::SizeLimitExceeded);
261        }
262
263        // SAFETY: We are passing an array of length 1, which corresponds to exactly
264        // the header. The length inside the header is set to 0, and there are also no
265        // further elements in the vector that would constitute any T::Entry.
266        unsafe { Ok(Self::from_raw(vec![header])) }
267    }
268
269    /// Create a new FamStructWrapper from a slice of elements.
270    ///
271    /// # Arguments
272    ///
273    /// * `entries` - The slice of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry)
274    ///   entries.
275    ///
276    /// # Errors
277    ///
278    /// When the size of `entries` is greater than the max possible len, it returns
279    /// `Error::SizeLimitExceeded`.
280    pub fn from_entries(entries: &[T::Entry]) -> Result<FamStructWrapper<T>, Error> {
281        let mut adapter = FamStructWrapper::<T>::new(entries.len())?;
282
283        {
284            // SAFETY: We are not modifying the length of the FamStruct
285            let wrapper_entries = unsafe { adapter.as_mut_fam_struct().as_mut_slice() };
286            wrapper_entries.copy_from_slice(entries);
287        }
288
289        Ok(adapter)
290    }
291
292    /// Create a new FamStructWrapper from the raw content represented as `Vec<T>`.
293    ///
294    /// Sometimes we already have the raw content of an FAM struct represented as `Vec<T>`,
295    /// and want to use the FamStructWrapper as accessors.
296    ///
297    /// # Arguments
298    ///
299    /// * `content` - The raw content represented as `Vec[T]`.
300    ///
301    /// # Safety
302    ///
303    /// This function is unsafe because the caller needs to ensure that the raw content is
304    /// correctly layed out.
305    pub unsafe fn from_raw(content: Vec<T>) -> Self {
306        debug_assert_ne!(content.len(), 0);
307        debug_assert!(content[0].len() <= Self::fam_len(content.len()));
308
309        FamStructWrapper {
310            mem_allocator: content,
311        }
312    }
313
314    /// Consume the FamStructWrapper and return the raw content as `Vec<T>`.
315    pub fn into_raw(self) -> Vec<T> {
316        self.mem_allocator
317    }
318
319    /// Get a reference to the actual [`FamStruct`](trait.FamStruct.html) instance.
320    pub fn as_fam_struct_ref(&self) -> &T {
321        &self.mem_allocator[0]
322    }
323
324    /// Get a mut reference to the actual [`FamStruct`](trait.FamStruct.html) instance.
325    ///
326    /// # Safety
327    ///
328    /// Callers must not use the reference returned to modify the `len` field of the underlying
329    /// `FamStruct`. See also the top-level documentation of [`FamStruct`].
330    pub unsafe fn as_mut_fam_struct(&mut self) -> &mut T {
331        &mut self.mem_allocator[0]
332    }
333
334    /// Get a pointer to the [`FamStruct`](trait.FamStruct.html) instance.
335    ///
336    /// The caller must ensure that the fam_struct outlives the pointer this
337    /// function returns, or else it will end up pointing to garbage.
338    ///
339    /// Modifying the container referenced by this pointer may cause its buffer
340    /// to be reallocated, which would also make any pointers to it invalid.
341    pub fn as_fam_struct_ptr(&self) -> *const T {
342        self.as_fam_struct_ref()
343    }
344
345    /// Get a mutable pointer to the [`FamStruct`](trait.FamStruct.html) instance.
346    ///
347    /// The caller must ensure that the fam_struct outlives the pointer this
348    /// function returns, or else it will end up pointing to garbage.
349    ///
350    /// Modifying the container referenced by this pointer may cause its buffer
351    /// to be reallocated, which would also make any pointers to it invalid.
352    pub fn as_mut_fam_struct_ptr(&mut self) -> *mut T {
353        // SAFETY: We do not change the length of the underlying FamStruct.
354        unsafe { self.as_mut_fam_struct() }
355    }
356
357    /// Get the elements slice.
358    pub fn as_slice(&self) -> &[T::Entry] {
359        self.as_fam_struct_ref().as_slice()
360    }
361
362    /// Get the mutable elements slice.
363    pub fn as_mut_slice(&mut self) -> &mut [T::Entry] {
364        // SAFETY: We do not change the length of the underlying FamStruct.
365        unsafe { self.as_mut_fam_struct() }.as_mut_slice()
366    }
367
368    /// Get the number of elements of type `FamStruct::Entry` currently in the vec.
369    fn len(&self) -> usize {
370        self.as_fam_struct_ref().len()
371    }
372
373    /// Get the capacity of the `FamStructWrapper`
374    ///
375    /// The capacity is measured in elements of type `FamStruct::Entry`.
376    fn capacity(&self) -> usize {
377        FamStructWrapper::<T>::fam_len(self.mem_allocator.capacity())
378    }
379
380    /// Reserve additional capacity.
381    ///
382    /// Reserve capacity for at least `additional` more
383    /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry) elements.
384    ///
385    /// If the capacity is already reserved, this method doesn't do anything.
386    /// If not this will trigger a reallocation of the underlying buffer.
387    fn reserve(&mut self, additional: usize) -> Result<(), Error> {
388        let desired_capacity = self.len() + additional;
389        if desired_capacity <= self.capacity() {
390            return Ok(());
391        }
392
393        let current_mem_allocator_len = self.mem_allocator.len();
394        let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(desired_capacity)
395            .ok_or(Error::SizeLimitExceeded)?;
396        let additional_mem_allocator_len = required_mem_allocator_len - current_mem_allocator_len;
397
398        self.mem_allocator.reserve(additional_mem_allocator_len);
399
400        Ok(())
401    }
402
403    /// Update the length of the FamStructWrapper.
404    ///
405    /// The length of `self` will be updated to the specified value.
406    /// The length of the `T` structure and of `self.mem_allocator` will be updated accordingly.
407    /// If the len is increased additional capacity will be reserved.
408    /// If the len is decreased the unnecessary memory will be deallocated.
409    ///
410    /// This method might trigger reallocations of the underlying buffer.
411    ///
412    /// # Errors
413    ///
414    /// When len is greater than the max possible len it returns Error::SizeLimitExceeded.
415    fn set_len(&mut self, len: usize) -> Result<(), Error> {
416        let additional_elements = isize::try_from(len)
417            .and_then(|len| isize::try_from(self.len()).map(|self_len| len - self_len))
418            .map_err(|_| Error::SizeLimitExceeded)?;
419
420        // If len == self.len there's nothing to do.
421        if additional_elements == 0 {
422            return Ok(());
423        }
424
425        // If the len needs to be increased:
426        if additional_elements > 0 {
427            // Check if the new len is valid.
428            if len > T::max_len() {
429                return Err(Error::SizeLimitExceeded);
430            }
431            // Reserve additional capacity.
432            self.reserve(additional_elements as usize)?;
433        }
434
435        let current_mem_allocator_len = self.mem_allocator.len();
436        let required_mem_allocator_len =
437            FamStructWrapper::<T>::mem_allocator_len(len).ok_or(Error::SizeLimitExceeded)?;
438        // Update the len of the `mem_allocator`.
439        // SAFETY: This is safe since enough capacity has been reserved.
440        unsafe {
441            self.mem_allocator.set_len(required_mem_allocator_len);
442        }
443        // Zero-initialize the additional elements if any.
444        for i in current_mem_allocator_len..required_mem_allocator_len {
445            // SAFETY: Safe as long as the trait is only implemented for POD. This is a requirement
446            // for the trait implementation.
447            self.mem_allocator[i] = unsafe { mem::zeroed() }
448        }
449        // Update the len of the underlying `FamStruct`.
450        // SAFETY: We just adjusted the memory for the underlying `mem_allocator` to hold `len`
451        // entries.
452        unsafe {
453            self.as_mut_fam_struct().set_len(len);
454        }
455
456        // If the len needs to be decreased, deallocate unnecessary memory
457        if additional_elements < 0 {
458            self.mem_allocator.shrink_to_fit();
459        }
460
461        Ok(())
462    }
463
464    /// Append an element.
465    ///
466    /// # Arguments
467    ///
468    /// * `entry` - The element that will be appended to the end of the collection.
469    ///
470    /// # Errors
471    ///
472    /// When len is already equal to max possible len it returns Error::SizeLimitExceeded.
473    pub fn push(&mut self, entry: T::Entry) -> Result<(), Error> {
474        let new_len = self.len() + 1;
475        self.set_len(new_len)?;
476        self.as_mut_slice()[new_len - 1] = entry;
477
478        Ok(())
479    }
480
481    /// Retain only the elements specified by the predicate.
482    ///
483    /// # Arguments
484    ///
485    /// * `f` - The function used to evaluate whether an entry will be kept or not.
486    ///   When `f` returns `true` the entry is kept.
487    pub fn retain<P>(&mut self, mut f: P)
488    where
489        P: FnMut(&T::Entry) -> bool,
490    {
491        let mut num_kept_entries = 0;
492        {
493            let entries = self.as_mut_slice();
494            for entry_idx in 0..entries.len() {
495                let keep = f(&entries[entry_idx]);
496                if keep {
497                    entries[num_kept_entries] = entries[entry_idx];
498                    num_kept_entries += 1;
499                }
500            }
501        }
502
503        // This is safe since this method is not increasing the len
504        self.set_len(num_kept_entries).expect("invalid length");
505    }
506}
507
508impl<T: Default + FamStruct + PartialEq> PartialEq for FamStructWrapper<T>
509where
510    T::Entry: PartialEq,
511{
512    fn eq(&self, other: &FamStructWrapper<T>) -> bool {
513        self.as_fam_struct_ref() == other.as_fam_struct_ref() && self.as_slice() == other.as_slice()
514    }
515}
516
517impl<T: Default + FamStruct> Clone for FamStructWrapper<T> {
518    fn clone(&self) -> Self {
519        // The number of entries (self.as_slice().len()) can't be > T::max_len() since `self` is a
520        // valid `FamStructWrapper`. This makes the .unwrap() safe.
521        let required_mem_allocator_capacity =
522            FamStructWrapper::<T>::mem_allocator_len(self.as_slice().len()).unwrap();
523
524        let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);
525
526        // SAFETY: This is safe as long as the requirements for the `FamStruct` trait to be safe
527        // are met (the implementing type and the entries elements are POD, therefore `Copy`, so
528        // memory safety can't be violated by the ownership of `fam_struct`). It is also safe
529        // because we're trying to read a T from a `&T` that is pointing to a properly initialized
530        // and aligned T.
531        unsafe {
532            let fam_struct: T = std::ptr::read(self.as_fam_struct_ref());
533            mem_allocator.push(fam_struct);
534        }
535        for _ in 1..required_mem_allocator_capacity {
536            mem_allocator.push(
537                // SAFETY: This is safe as long as T respects the FamStruct trait and is a POD.
538                unsafe { mem::zeroed() },
539            )
540        }
541
542        let mut adapter = FamStructWrapper { mem_allocator };
543        {
544            let wrapper_entries = adapter.as_mut_slice();
545            wrapper_entries.copy_from_slice(self.as_slice());
546        }
547        adapter
548    }
549}
550
#[cfg(feature = "with-serde")]
impl<T: Default + FamStruct + Serialize> Serialize for FamStructWrapper<T>
where
    <T as FamStruct>::Entry: serde::Serialize,
{
    /// Serializes the wrapper as a 2-tuple: the header first, then the slice of entries.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let header = self.as_fam_struct_ref();
        let entries = self.as_slice();

        let mut tuple = serializer.serialize_tuple(2)?;
        tuple.serialize_element(header)?;
        tuple.serialize_element(entries)?;
        tuple.end()
    }
}
566
#[cfg(feature = "with-serde")]
impl<'de, T: Default + FamStruct + Deserialize<'de>> Deserialize<'de> for FamStructWrapper<T>
where
    <T as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>,
{
    /// Deserializes a wrapper from a 2-tuple made of the header followed by the entries.
    ///
    /// The length stored in the header must match the number of deserialized entries,
    /// otherwise a custom deserialization error is returned.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor that assembles a `FamStructWrapper<X>` out of a (header, entries) sequence.
        struct WrapperVisitor<X>(PhantomData<X>);

        impl<'de, X: Default + FamStruct + Deserialize<'de>> Visitor<'de> for WrapperVisitor<X>
        where
            <X as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>,
        {
            type Value = FamStructWrapper<X>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("FamStructWrapper")
            }

            fn visit_seq<V>(self, mut seq: V) -> Result<FamStructWrapper<X>, V::Error>
            where
                V: SeqAccess<'de>,
            {
                use serde::de::Error;

                let header: X = match seq.next_element()? {
                    Some(header) => header,
                    None => return Err(de::Error::invalid_length(0, &self)),
                };
                let entries: Vec<X::Entry> = match seq.next_element()? {
                    Some(entries) => entries,
                    None => return Err(de::Error::invalid_length(1, &self)),
                };

                // Reject input whose header disagrees with the amount of entries provided.
                let (declared_len, actual_len) = (header.len(), entries.len());
                if declared_len != actual_len {
                    let msg = format!(
                        "Mismatch between length of FAM specified in FamStruct header ({}) \
                         and actual size of FAM ({})",
                        declared_len, actual_len
                    );
                    return Err(V::Error::custom(msg));
                }

                let mut wrapper = FamStructWrapper::<X>::from_entries(entries.as_slice())
                    .map_err(|e| V::Error::custom(format!("{:?}", e)))?;
                // Swap the default header created by `from_entries` for the deserialized one.
                wrapper.mem_allocator[0] = header;
                Ok(wrapper)
            }
        }

        deserializer.deserialize_tuple(2, WrapperVisitor(PhantomData))
    }
}
623
/// Generate `FamStruct` implementation for structs with flexible array member.
///
/// # Arguments
///
/// * `$struct_type` - the type the `FamStruct` implementation is generated for.
/// * `$entry_type` - the type of the flexible array member entries.
/// * `$entries_name` - the name of the field holding the flexible array member. Its type must
///   provide `as_slice(len)` and `as_mut_slice(len)` methods.
/// * `$field_type` - the type of the field holding the FAM length.
/// * `$field_name` - the name of the field holding the FAM length.
/// * `$max` - the expression returned by `max_len()`.
///
/// The expansion refers to `FamStruct` unqualified, so the trait must be in scope where the
/// macro is invoked.
#[macro_export]
macro_rules! generate_fam_struct_impl {
    ($struct_type: ty, $entry_type: ty, $entries_name: ident,
     $field_type: ty, $field_name: ident, $max: expr) => {
        unsafe impl FamStruct for $struct_type {
            type Entry = $entry_type;

            fn len(&self) -> usize {
                self.$field_name as usize
            }

            unsafe fn set_len(&mut self, len: usize) {
                // NOTE: `as` truncates silently if `len` does not fit in `$field_type`.
                self.$field_name = len as $field_type;
            }

            fn max_len() -> usize {
                $max
            }

            fn as_slice(&self) -> &[<Self as FamStruct>::Entry] {
                let len = self.len();
                // SAFETY: relies on the `FamStruct` contract that the length field correctly
                // describes the number of entries of the flexible array member.
                unsafe { self.$entries_name.as_slice(len) }
            }

            fn as_mut_slice(&mut self) -> &mut [<Self as FamStruct>::Entry] {
                let len = self.len();
                // SAFETY: relies on the `FamStruct` contract that the length field correctly
                // describes the number of entries of the flexible array member.
                unsafe { self.$entries_name.as_mut_slice(len) }
            }
        }
    };
}
656
657#[cfg(test)]
658mod tests {
659    #![allow(clippy::undocumented_unsafe_blocks)]
660
661    #[cfg(feature = "with-serde")]
662    use serde_derive::{Deserialize, Serialize};
663
664    use super::*;
665
666    const MAX_LEN: usize = 100;
667
668    #[repr(C)]
669    #[derive(Default, Debug, PartialEq, Eq)]
670    pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
671    impl<T> __IncompleteArrayField<T> {
672        #[inline]
673        pub fn new() -> Self {
674            __IncompleteArrayField(::std::marker::PhantomData, [])
675        }
676        #[inline]
677        pub unsafe fn as_ptr(&self) -> *const T {
678            self as *const __IncompleteArrayField<T> as *const T
679        }
680        #[inline]
681        pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
682            self as *mut __IncompleteArrayField<T> as *mut T
683        }
684        #[inline]
685        pub unsafe fn as_slice(&self, len: usize) -> &[T] {
686            ::std::slice::from_raw_parts(self.as_ptr(), len)
687        }
688        #[inline]
689        pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
690            ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
691        }
692    }
693
694    #[cfg(feature = "with-serde")]
695    impl<T> Serialize for __IncompleteArrayField<T> {
696        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
697        where
698            S: Serializer,
699        {
700            [0u8; 0].serialize(serializer)
701        }
702    }
703
704    #[cfg(feature = "with-serde")]
705    impl<'de, T> Deserialize<'de> for __IncompleteArrayField<T> {
706        fn deserialize<D>(_: D) -> std::result::Result<Self, D::Error>
707        where
708            D: Deserializer<'de>,
709        {
710            Ok(__IncompleteArrayField::new())
711        }
712    }
713
    // FAM struct used by the tests: a header made of two `u32` fields followed by a flexible
    // array member of `u32` entries.
    #[repr(C)]
    #[derive(Default, PartialEq)]
    struct MockFamStruct {
        // Number of entries in the flexible array member.
        pub len: u32,
        pub padding: u32,
        // Zero-sized marker for the flexible array member.
        pub entries: __IncompleteArrayField<u32>,
    }

    // `FamStruct` implementation: `u32` entries, `u32` length field, at most 100 entries.
    generate_fam_struct_impl!(MockFamStruct, u32, entries, u32, len, 100);

    type MockFamStructWrapper = FamStructWrapper<MockFamStruct>;

    // Number of `u32` words taken by the header (`len` and `padding`) before the entries
    // start, when the wrapper memory is viewed as a `u32` slice.
    const ENTRIES_OFFSET: usize = 2;

    // Pairs of (FAM len, expected `mem_allocator` len) for `MockFamStruct`.
    const FAM_LEN_TO_MEM_ALLOCATOR_LEN: &[(usize, usize)] = &[
        (0, 1),
        (1, 2),
        (2, 2),
        (3, 3),
        (4, 3),
        (5, 4),
        (10, 6),
        (50, 26),
        (100, 51),
    ];

    // Pairs of (`mem_allocator` len, expected FAM len) for `MockFamStruct`.
    const MEM_ALLOCATOR_LEN_TO_FAM_LEN: &[(usize, usize)] = &[
        (0, 0),
        (1, 0),
        (2, 2),
        (3, 4),
        (4, 6),
        (5, 8),
        (10, 18),
        (50, 98),
        (100, 198),
    ];
751
752    #[test]
753    fn test_mem_allocator_len() {
754        for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN {
755            let fam_len = pair.0;
756            let mem_allocator_len = pair.1;
757            assert_eq!(
758                Some(mem_allocator_len),
759                MockFamStructWrapper::mem_allocator_len(fam_len)
760            );
761        }
762    }
763
    // Variant of the mock FAM struct whose entries are single bytes (`u8`).
    #[repr(C)]
    #[derive(Default, PartialEq)]
    struct MockFamStructU8 {
        // Number of entries in the flexible array member.
        pub len: u32,
        pub padding: u32,
        // Zero-sized marker for the flexible array member.
        pub entries: __IncompleteArrayField<u8>,
    }
    // `FamStruct` implementation: `u8` entries, `u32` length field, at most 100 entries.
    generate_fam_struct_impl!(MockFamStructU8, u8, entries, u32, len, 100);
    type MockFamStructWrapperU8 = FamStructWrapper<MockFamStructU8>;
773    #[test]
774    fn test_invalid_type_conversion() {
775        let mut adapter = MockFamStructWrapperU8::new(10).unwrap();
776        assert!(matches!(
777            adapter.set_len(0xffff_ffff_ffff_ff00),
778            Err(Error::SizeLimitExceeded)
779        ));
780    }
781
782    #[test]
783    fn test_wrapper_len() {
784        for pair in MEM_ALLOCATOR_LEN_TO_FAM_LEN {
785            let mem_allocator_len = pair.0;
786            let fam_len = pair.1;
787            assert_eq!(fam_len, MockFamStructWrapper::fam_len(mem_allocator_len));
788        }
789    }
790
791    #[test]
792    fn test_new() {
793        let num_entries = 10;
794
795        let adapter = MockFamStructWrapper::new(num_entries).unwrap();
796        assert_eq!(num_entries, adapter.capacity());
797
798        let u32_slice = unsafe {
799            std::slice::from_raw_parts(
800                adapter.as_fam_struct_ptr() as *const u32,
801                num_entries + ENTRIES_OFFSET,
802            )
803        };
804        assert_eq!(num_entries, u32_slice[0] as usize);
805        for entry in u32_slice[1..].iter() {
806            assert_eq!(*entry, 0);
807        }
808
809        // It's okay to create a `FamStructWrapper` with the maximum allowed number of entries.
810        let adapter = MockFamStructWrapper::new(MockFamStruct::max_len()).unwrap();
811        assert_eq!(MockFamStruct::max_len(), adapter.capacity());
812
813        assert!(matches!(
814            MockFamStructWrapper::new(MockFamStruct::max_len() + 1),
815            Err(Error::SizeLimitExceeded)
816        ));
817    }
818
819    #[test]
820    fn test_from_entries() {
821        let num_entries: usize = 10;
822
823        let mut entries = Vec::new();
824        for i in 0..num_entries {
825            entries.push(i as u32);
826        }
827
828        let adapter = MockFamStructWrapper::from_entries(entries.as_slice()).unwrap();
829        let u32_slice = unsafe {
830            std::slice::from_raw_parts(
831                adapter.as_fam_struct_ptr() as *const u32,
832                num_entries + ENTRIES_OFFSET,
833            )
834        };
835        assert_eq!(num_entries, u32_slice[0] as usize);
836        for (i, &value) in entries.iter().enumerate().take(num_entries) {
837            assert_eq!(adapter.as_slice()[i], value);
838        }
839
840        let mut entries = Vec::new();
841        for i in 0..MockFamStruct::max_len() + 1 {
842            entries.push(i as u32);
843        }
844
845        // Can't create a `FamStructWrapper` with a number of entries > MockFamStruct::max_len().
846        assert!(matches!(
847            MockFamStructWrapper::from_entries(entries.as_slice()),
848            Err(Error::SizeLimitExceeded)
849        ));
850    }
851
852    #[test]
853    fn test_entries_slice() {
854        let num_entries = 10;
855        let mut adapter = MockFamStructWrapper::new(num_entries).unwrap();
856
857        let expected_slice = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
858
859        {
860            let mut_entries_slice = adapter.as_mut_slice();
861            mut_entries_slice.copy_from_slice(expected_slice);
862        }
863
864        let u32_slice = unsafe {
865            std::slice::from_raw_parts(
866                adapter.as_fam_struct_ptr() as *const u32,
867                num_entries + ENTRIES_OFFSET,
868            )
869        };
870        assert_eq!(expected_slice, &u32_slice[ENTRIES_OFFSET..]);
871        assert_eq!(expected_slice, adapter.as_slice());
872    }
873
874    #[test]
875    fn test_reserve() {
876        let mut adapter = MockFamStructWrapper::new(0).unwrap();
877
878        // test that the right capacity is reserved
879        for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN {
880            let num_elements = pair.0;
881            let required_mem_allocator_len = pair.1;
882
883            adapter.reserve(num_elements).unwrap();
884
885            assert!(adapter.mem_allocator.capacity() >= required_mem_allocator_len);
886            assert_eq!(0, adapter.len());
887            assert!(adapter.capacity() >= num_elements);
888        }
889
890        // test that when the capacity is already reserved, the method doesn't do anything
891        let current_capacity = adapter.capacity();
892        adapter.reserve(current_capacity - 1).unwrap();
893        assert_eq!(current_capacity, adapter.capacity());
894    }
895
896    #[test]
897    fn test_set_len() {
898        let mut desired_len = 0;
899        let mut adapter = MockFamStructWrapper::new(desired_len).unwrap();
900
901        // keep initial len
902        assert!(adapter.set_len(desired_len).is_ok());
903        assert_eq!(adapter.len(), desired_len);
904
905        // increase len
906        desired_len = 10;
907        assert!(adapter.set_len(desired_len).is_ok());
908        // check that the len has been increased and zero-initialized elements have been added
909        assert_eq!(adapter.len(), desired_len);
910        for element in adapter.as_slice() {
911            assert_eq!(*element, 0_u32);
912        }
913
914        // decrease len
915        desired_len = 5;
916        assert!(adapter.set_len(desired_len).is_ok());
917        assert_eq!(adapter.len(), desired_len);
918    }
919
920    #[test]
921    fn test_push() {
922        let mut adapter = MockFamStructWrapper::new(0).unwrap();
923
924        for i in 0..MAX_LEN {
925            assert!(adapter.push(i as u32).is_ok());
926            assert_eq!(adapter.as_slice()[i], i as u32);
927            assert_eq!(adapter.len(), i + 1);
928            assert!(
929                adapter.mem_allocator.capacity()
930                    >= MockFamStructWrapper::mem_allocator_len(i + 1).unwrap()
931            );
932        }
933
934        assert!(adapter.push(0).is_err());
935    }
936
937    #[test]
938    fn test_retain() {
939        let mut adapter = MockFamStructWrapper::new(0).unwrap();
940
941        let mut num_retained_entries = 0;
942        for i in 0..MAX_LEN {
943            assert!(adapter.push(i as u32).is_ok());
944            if i % 2 == 0 {
945                num_retained_entries += 1;
946            }
947        }
948
949        adapter.retain(|entry| entry % 2 == 0);
950
951        for entry in adapter.as_slice().iter() {
952            assert_eq!(0, entry % 2);
953        }
954        assert_eq!(adapter.len(), num_retained_entries);
955        assert!(
956            adapter.mem_allocator.capacity()
957                >= MockFamStructWrapper::mem_allocator_len(num_retained_entries).unwrap()
958        );
959    }
960
961    #[test]
962    fn test_partial_eq() {
963        let mut wrapper_1 = MockFamStructWrapper::new(0).unwrap();
964        let mut wrapper_2 = MockFamStructWrapper::new(0).unwrap();
965        let mut wrapper_3 = MockFamStructWrapper::new(0).unwrap();
966
967        for i in 0..MAX_LEN {
968            assert!(wrapper_1.push(i as u32).is_ok());
969            assert!(wrapper_2.push(i as u32).is_ok());
970            assert!(wrapper_3.push(0).is_ok());
971        }
972
973        assert!(wrapper_1 == wrapper_2);
974        assert!(wrapper_1 != wrapper_3);
975    }
976
977    #[test]
978    fn test_clone() {
979        let mut adapter = MockFamStructWrapper::new(0).unwrap();
980
981        for i in 0..MAX_LEN {
982            assert!(adapter.push(i as u32).is_ok());
983        }
984
985        assert!(adapter == adapter.clone());
986    }
987
988    #[test]
989    fn test_from_header() {
990        let header = MockFamStruct::default();
991        let wrapper = MockFamStructWrapper::from_header(header).unwrap();
992        assert_eq!(wrapper.len(), 0);
993        assert_eq!(wrapper.as_fam_struct_ref().len, 0);
994
995        let header = MockFamStruct {
996            len: 100,
997            ..Default::default()
998        };
999        let error = MockFamStructWrapper::from_header(header);
1000        assert!(matches!(error, Err(Error::SizeLimitExceeded)));
1001    }
1002
1003    #[test]
1004    fn test_raw_content() {
1005        let data = vec![
1006            MockFamStruct {
1007                len: 2,
1008                padding: 5,
1009                entries: __IncompleteArrayField::new(),
1010            },
1011            MockFamStruct {
1012                len: 0xA5,
1013                padding: 0x1e,
1014                entries: __IncompleteArrayField::new(),
1015            },
1016        ];
1017
1018        let mut wrapper = unsafe { MockFamStructWrapper::from_raw(data) };
1019        {
1020            let payload = wrapper.as_slice();
1021            assert_eq!(payload[0], 0xA5);
1022            assert_eq!(payload[1], 0x1e);
1023        }
1024        assert_eq!(unsafe { wrapper.as_mut_fam_struct() }.padding, 5);
1025        let data = wrapper.into_raw();
1026        assert_eq!(data[0].len, 2);
1027        assert_eq!(data[0].padding, 5);
1028    }
1029
1030    #[cfg(feature = "with-serde")]
1031    #[test]
1032    fn test_ser_deser() {
1033        #[repr(C)]
1034        #[derive(Default, PartialEq)]
1035        #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
1036        struct Message {
1037            pub len: u32,
1038            pub padding: u32,
1039            pub value: u32,
1040            #[cfg_attr(feature = "with-serde", serde(skip))]
1041            pub entries: __IncompleteArrayField<u32>,
1042        }
1043
1044        generate_fam_struct_impl!(Message, u32, entries, u32, len, 100);
1045
1046        type MessageFamStructWrapper = FamStructWrapper<Message>;
1047
1048        let data = vec![
1049            Message {
1050                len: 2,
1051                padding: 0,
1052                value: 42,
1053                entries: __IncompleteArrayField::new(),
1054            },
1055            Message {
1056                len: 0xA5,
1057                padding: 0x1e,
1058                value: 0,
1059                entries: __IncompleteArrayField::new(),
1060            },
1061        ];
1062
1063        let wrapper = unsafe { MessageFamStructWrapper::from_raw(data) };
1064        let data_ser = serde_json::to_string(&wrapper).unwrap();
1065        assert_eq!(
1066            data_ser,
1067            "[{\"len\":2,\"padding\":0,\"value\":42},[165,30]]"
1068        );
1069        let data_deser =
1070            serde_json::from_str::<MessageFamStructWrapper>(data_ser.as_str()).unwrap();
1071        assert!(wrapper.eq(&data_deser));
1072
1073        let bad_data_ser = r#"{"foo": "bar"}"#;
1074        assert!(serde_json::from_str::<MessageFamStructWrapper>(bad_data_ser).is_err());
1075
1076        #[repr(C)]
1077        #[derive(Default)]
1078        #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
1079        struct Message2 {
1080            pub len: u32,
1081            pub padding: u32,
1082            pub value: u32,
1083            #[cfg_attr(feature = "with-serde", serde(skip))]
1084            pub entries: __IncompleteArrayField<u32>,
1085        }
1086
1087        // Maximum number of entries = 1, so the deserialization should fail because of this reason.
1088        generate_fam_struct_impl!(Message2, u32, entries, u32, len, 1);
1089
1090        type Message2FamStructWrapper = FamStructWrapper<Message2>;
1091        assert!(serde_json::from_str::<Message2FamStructWrapper>(data_ser.as_str()).is_err());
1092    }
1093
1094    #[test]
1095    fn test_clone_multiple_fields() {
1096        #[derive(Default, PartialEq)]
1097        #[repr(C)]
1098        struct Foo {
1099            index: u32,
1100            length: u16,
1101            flags: u32,
1102            entries: __IncompleteArrayField<u32>,
1103        }
1104
1105        generate_fam_struct_impl!(Foo, u32, entries, u16, length, 100);
1106
1107        type FooFamStructWrapper = FamStructWrapper<Foo>;
1108
1109        let mut wrapper = FooFamStructWrapper::new(0).unwrap();
1110        // SAFETY: We do play with length here, but that's just for testing purposes :)
1111        unsafe {
1112            wrapper.as_mut_fam_struct().index = 1;
1113            wrapper.as_mut_fam_struct().flags = 2;
1114            wrapper.as_mut_fam_struct().length = 3;
1115            wrapper.push(3).unwrap();
1116            wrapper.push(14).unwrap();
1117            assert_eq!(wrapper.as_slice().len(), 3 + 2);
1118            assert_eq!(wrapper.as_slice()[3], 3);
1119            assert_eq!(wrapper.as_slice()[3 + 1], 14);
1120
1121            let mut wrapper2 = wrapper.clone();
1122            assert_eq!(
1123                wrapper.as_mut_fam_struct().index,
1124                wrapper2.as_mut_fam_struct().index
1125            );
1126            assert_eq!(
1127                wrapper.as_mut_fam_struct().length,
1128                wrapper2.as_mut_fam_struct().length
1129            );
1130            assert_eq!(
1131                wrapper.as_mut_fam_struct().flags,
1132                wrapper2.as_mut_fam_struct().flags
1133            );
1134            assert_eq!(wrapper.as_slice(), wrapper2.as_slice());
1135            assert_eq!(
1136                wrapper2.as_slice().len(),
1137                wrapper2.as_mut_fam_struct().length as usize
1138            );
1139            assert!(wrapper == wrapper2);
1140
1141            wrapper.as_mut_fam_struct().index = 3;
1142            assert!(wrapper != wrapper2);
1143
1144            wrapper.as_mut_fam_struct().length = 7;
1145            assert!(wrapper != wrapper2);
1146
1147            wrapper.push(1).unwrap();
1148            assert_eq!(wrapper.as_mut_fam_struct().length, 8);
1149            assert!(wrapper != wrapper2);
1150
1151            let mut wrapper2 = wrapper.clone();
1152            assert!(wrapper == wrapper2);
1153
1154            // Dropping the original variable should not affect its clone.
1155            drop(wrapper);
1156            assert_eq!(wrapper2.as_mut_fam_struct().index, 3);
1157            assert_eq!(wrapper2.as_mut_fam_struct().length, 8);
1158            assert_eq!(wrapper2.as_mut_fam_struct().flags, 2);
1159            assert_eq!(wrapper2.as_slice(), [0, 0, 0, 3, 14, 0, 0, 1]);
1160        }
1161    }
1162
1163    #[cfg(feature = "with-serde")]
1164    #[test]
1165    fn test_bad_deserialize() {
1166        #[repr(C)]
1167        #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
1168        struct Foo {
1169            pub len: u32,
1170            pub padding: u32,
1171            pub entries: __IncompleteArrayField<u32>,
1172        }
1173
1174        generate_fam_struct_impl!(Foo, u32, entries, u32, len, 100);
1175
1176        let state = FamStructWrapper::<Foo>::new(0).unwrap();
1177        let mut bytes = bincode::serialize(&state).unwrap();
1178
1179        // The `len` field of the header is the first to be serialized.
1180        // Writing at position 0 of the serialized data should change its value.
1181        bytes[0] = 255;
1182
1183        assert!(
1184            matches!(bincode::deserialize::<FamStructWrapper<Foo>>(&bytes).map_err(|boxed| *boxed), Err(bincode::ErrorKind::Custom(s)) if s == *"Mismatch between length of FAM specified in FamStruct header (255) and actual size of FAM (0)")
1185        );
1186    }
1187}