Skip to main content

vortex_buffer/
alignment.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::ops::Deref;
6
7use vortex_error::VortexError;
8use vortex_error::VortexExpect;
9use vortex_error::vortex_err;
10
11/// Default alignment for device-to-host buffer copies.
12pub const ALIGNMENT_TO_HOST_COPY: Alignment = Alignment::new(256);
13
/// A validated byte alignment for buffers.
///
/// The wrapped `usize` is always a non-zero power of 2 that is no larger than `u16::MAX`, so it
/// can be narrowed to a `u16` without loss.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Alignment(usize);
20
21impl Alignment {
22    /// Create a new alignment.
23    ///
24    /// ## Panics
25    ///
26    /// Panics if `align` is not a power of 2, or is greater than `u16::MAX`.
27    #[inline]
28    pub const fn new(align: usize) -> Self {
29        assert!(align > 0, "Alignment must be greater than 0");
30        assert!(align <= u16::MAX as usize, "Alignment must fit into u16");
31        assert!(align.is_power_of_two(), "Alignment must be a power of 2");
32        Self(align)
33    }
34
35    /// Create a new 1-byte alignment.
36    #[inline]
37    pub const fn none() -> Self {
38        Self::new(1)
39    }
40
41    /// Create an alignment from the alignment of a type `T`.
42    ///
43    /// ## Example
44    ///
45    /// ```
46    /// use vortex_buffer::Alignment;
47    ///
48    /// assert_eq!(Alignment::new(4), Alignment::of::<i32>());
49    /// assert_eq!(Alignment::new(8), Alignment::of::<i64>());
50    /// assert_eq!(Alignment::new(16), Alignment::of::<u128>());
51    /// ```
52    #[inline]
53    pub const fn of<T>() -> Self {
54        Self::new(align_of::<T>())
55    }
56
57    /// Check if `self` alignment is a "larger" than `other` alignment.
58    ///
59    /// ## Example
60    ///
61    /// ```
62    /// use vortex_buffer::Alignment;
63    ///
64    /// let a = Alignment::new(4);
65    /// let b = Alignment::new(2);
66    /// assert!(a.is_aligned_to(b));
67    /// assert!(!b.is_aligned_to(a));
68    /// ```
69    #[inline]
70    pub fn is_aligned_to(&self, other: Alignment) -> bool {
71        // Since we know alignments are powers of 2, we can compare them by checking if the number
72        // of trailing zeros in the binary representation of the alignment is greater or equal.
73        self.0.trailing_zeros() >= other.0.trailing_zeros()
74    }
75
76    /// Returns the log2 of the alignment.
77    pub fn exponent(&self) -> u8 {
78        u8::try_from(self.0.trailing_zeros())
79            .vortex_expect("alignment fits into u16, so exponent fits in u7")
80    }
81
82    /// Create from the log2 exponent of the alignment.
83    ///
84    /// ## Panics
85    ///
86    /// Panics if `alignment` is not a power of 2, or is greater than `u16::MAX`.
87    #[inline]
88    pub const fn from_exponent(exponent: u8) -> Self {
89        Self::new(1 << exponent)
90    }
91}
92
93impl Display for Alignment {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "{}", self.0)
96    }
97}
98
99impl Deref for Alignment {
100    type Target = usize;
101
102    #[inline]
103    fn deref(&self) -> &Self::Target {
104        &self.0
105    }
106}
107
108impl From<usize> for Alignment {
109    #[inline]
110    fn from(value: usize) -> Self {
111        Self::new(value)
112    }
113}
114
115impl From<u16> for Alignment {
116    #[inline]
117    fn from(value: u16) -> Self {
118        Self::new(usize::from(value))
119    }
120}
121
122impl From<Alignment> for usize {
123    #[inline]
124    fn from(value: Alignment) -> Self {
125        value.0
126    }
127}
128
129impl From<Alignment> for u16 {
130    #[inline]
131    fn from(value: Alignment) -> Self {
132        u16::try_from(value.0).vortex_expect("Alignment must fit into u16")
133    }
134}
135
136impl From<Alignment> for u32 {
137    #[inline]
138    fn from(value: Alignment) -> Self {
139        u32::try_from(value.0).vortex_expect("Alignment must fit into u32")
140    }
141}
142
143impl TryFrom<u32> for Alignment {
144    type Error = VortexError;
145
146    fn try_from(value: u32) -> Result<Self, Self::Error> {
147        let value = usize::try_from(value)
148            .map_err(|_| vortex_err!("Alignment must fit into usize, got {value}"))?;
149
150        if value == 0 {
151            return Err(vortex_err!("Alignment must be greater than 0"));
152        }
153        if value > u16::MAX as usize {
154            return Err(vortex_err!("Alignment must fit into u16, got {value}"));
155        }
156        if !value.is_power_of_two() {
157            return Err(vortex_err!("Alignment must be a power of 2, got {value}"));
158        }
159
160        Ok(Self(value))
161    }
162}
163
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[should_panic]
    fn alignment_zero() {
        Alignment::new(0);
    }

    #[test]
    #[should_panic]
    fn alignment_overflow() {
        Alignment::new(u16::MAX as usize + 1);
    }

    #[test]
    #[should_panic]
    fn alignment_not_power_of_two() {
        Alignment::new(3);
    }

    #[test]
    #[should_panic]
    fn from_exponent_overflow() {
        // 2^16 no longer fits into a `u16`.
        Alignment::from_exponent(16);
    }

    #[test]
    fn alignment_exponent() {
        let alignment = Alignment::new(1024);
        assert_eq!(alignment.exponent(), 10);
        assert_eq!(Alignment::from_exponent(10), alignment);
        // The extremes of the valid range round-trip as well.
        assert_eq!(Alignment::from_exponent(15).exponent(), 15);
        assert_eq!(Alignment::none().exponent(), 0);
    }

    #[test]
    fn is_aligned_to() {
        assert!(Alignment::new(1).is_aligned_to(Alignment::new(1)));
        assert!(Alignment::new(2).is_aligned_to(Alignment::new(1)));
        assert!(Alignment::new(4).is_aligned_to(Alignment::new(1)));
        assert!(Alignment::new(64).is_aligned_to(Alignment::new(64)));
        assert!(!Alignment::new(1).is_aligned_to(Alignment::new(2)));
        assert!(!Alignment::new(8).is_aligned_to(Alignment::new(16)));
    }

    #[test]
    fn try_from_u32() {
        match Alignment::try_from(8u32) {
            Ok(alignment) => assert_eq!(alignment, Alignment::new(8)),
            Err(err) => panic!("unexpected error for valid alignment: {err}"),
        }
        assert!(Alignment::try_from(0u32).is_err());
        assert!(Alignment::try_from(3u32).is_err());
    }

    #[test]
    fn try_from_u32_rejects_values_above_u16() {
        // A power of two just past the limit, and the widest possible input.
        assert!(Alignment::try_from(1u32 << 16).is_err());
        assert!(Alignment::try_from(u32::MAX).is_err());
        // The largest valid alignment is still accepted.
        assert!(Alignment::try_from(1u32 << 15).is_ok());
    }

    #[test]
    fn into_u32() {
        let alignment = Alignment::new(64);
        assert_eq!(u32::from(alignment), 64u32);
    }

    #[test]
    fn integer_conversions_round_trip() {
        let alignment = Alignment::from(32u16);
        assert_eq!(Alignment::from(32usize), alignment);
        assert_eq!(u16::from(alignment), 32);
        assert_eq!(usize::from(alignment), 32);
        assert_eq!(*alignment, 32);
    }

    #[test]
    fn display_and_host_copy_constant() {
        assert_eq!(Alignment::new(64).to_string(), "64");
        assert_eq!(ALIGNMENT_TO_HOST_COPY, Alignment::new(256));
    }
}