Skip to main content

solana_nullable/
maybe_null.rs

1//! Generic `Option`-like wrapper for types that can reserve a designated null
2//! value without adding a tag byte.
3//!
4//! For example, a 64-bit unsigned integer can designate `0` as a `None` value.
5//! This is equivalent to
6//! [`Option<NonZeroU64>`](https://doc.rust-lang.org/std/num/type.NonZeroU64.html)
7//! and provides the same memory layout optimization.
8
9use crate::Nullable;
10#[cfg(feature = "bytemuck")]
11use bytemuck::{Pod, Zeroable};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14#[cfg(feature = "wincode")]
15use wincode::{SchemaRead, SchemaWrite};
16#[cfg(feature = "borsh")]
17use {
18    alloc::format,
19    borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
20};
21
22/// A wrapper that can be used as an `Option<T>` without requiring extra space
23/// to indicate whether the value is `Some` or `None`.
24///
25/// This can be used when a specific value of `T` indicates that its value is
26/// `None`.
27#[repr(transparent)]
28#[cfg_attr(
29    feature = "borsh",
30    derive(BorshDeserialize, BorshSerialize, BorshSchema)
31)]
32#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))]
33#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
34pub struct MaybeNull<T: Nullable>(T);
35
/// # Safety
///
/// `MaybeNull<T>` is `#[repr(transparent)]` over its single field `T`, so its
/// in-memory layout is exactly that of `T`, and the derived wincode schema
/// writes and reads nothing but that `T`. Whenever `T` is zero-copy under the
/// config `C`, the wrapper is therefore zero-copy under `C` as well.
#[cfg(feature = "wincode")]
unsafe impl<T, C> wincode::config::ZeroCopy<C> for MaybeNull<T>
where
    T: Nullable + wincode::config::ZeroCopy<C>,
    C: wincode::config::ConfigCore,
{
}
46
47impl<T: Nullable> Default for MaybeNull<T> {
48    fn default() -> Self {
49        Self(T::NONE)
50    }
51}
52
53impl<T: Nullable> MaybeNull<T> {
54    /// Returns the contained value as an `Option`.
55    #[inline]
56    pub fn get(self) -> Option<T> {
57        if self.0.is_none() {
58            None
59        } else {
60            Some(self.0)
61        }
62    }
63
64    /// Returns a reference to the contained value as an `Option`.
65    #[inline]
66    pub fn as_ref(&self) -> Option<&T> {
67        if self.0.is_none() {
68            None
69        } else {
70            Some(&self.0)
71        }
72    }
73
74    /// Returns a mutable reference to the contained value as an `Option`.
75    #[inline]
76    pub fn as_mut(&mut self) -> Option<&mut T> {
77        if self.0.is_none() {
78            None
79        } else {
80            Some(&mut self.0)
81        }
82    }
83
84    /// Maps a `MaybeNull<T>` to an `Option<T>` by copying the contents of the option.
85    #[inline]
86    pub fn copied(&self) -> Option<T>
87    where
88        T: Copy,
89    {
90        self.as_ref().copied()
91    }
92
93    /// Maps a `MaybeNull<T>` to an `Option<T>` by cloning the contents of the option.
94    #[inline]
95    pub fn cloned(&self) -> Option<T>
96    where
97        T: Clone,
98    {
99        self.as_ref().cloned()
100    }
101}
102
103impl<T: Nullable> From<T> for MaybeNull<T> {
104    fn from(value: T) -> Self {
105        MaybeNull(value)
106    }
107}
108
109impl<T: Nullable> From<MaybeNull<T>> for Option<T> {
110    fn from(value: MaybeNull<T>) -> Self {
111        value.get()
112    }
113}
114
115impl<T: Nullable> TryFrom<Option<T>> for MaybeNull<T> {
116    type Error = MaybeNullError;
117
118    fn try_from(value: Option<T>) -> Result<Self, Self::Error> {
119        match value {
120            Some(value) if value.is_none() => Err(MaybeNullError::NoneValueInSome),
121            Some(value) => Ok(MaybeNull(value)),
122            None => Ok(MaybeNull(T::NONE)),
123        }
124    }
125}
126
/// Reasons a conversion into a [`MaybeNull`] can fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MaybeNullError {
    /// A `Some(value)` was supplied whose `value` is the designated null
    /// marker, so it cannot be stored without reading back as `None`.
    NoneValueInSome,
}
133
134impl core::fmt::Display for MaybeNullError {
135    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
136        match self {
137            Self::NoneValueInSome => {
138                write!(f, "cannot wrap None-equivalent value in Some")
139            }
140        }
141    }
142}
143
/// ## Safety
///
/// `MaybeNull<T>` is `#[repr(transparent)]` over its only field `T`, so it
/// shares `T`'s size, alignment and bit validity. Since `T: Pod`, `T` has no
/// padding and accepts every bit pattern, and therefore so does the wrapper.
#[cfg(feature = "bytemuck")]
unsafe impl<T: Nullable + Pod> Pod for MaybeNull<T> {}
150
/// ## Safety
///
/// `MaybeNull` is a `#[repr(transparent)]` wrapper around a bytemuck
/// `Zeroable` type `T` with identical data representation. An all-zeros
/// `MaybeNull<T>` is therefore an all-zeros `T`, which `T: Zeroable`
/// guarantees to be valid.
#[cfg(feature = "bytemuck")]
unsafe impl<T: Nullable + Zeroable> Zeroable for MaybeNull<T> {}
157
#[cfg(feature = "serde")]
impl<T> Serialize for MaybeNull<T>
where
    T: Nullable + Serialize,
{
    /// Serializes as an optional value. A stored value that reports
    /// `is_none()` is written as `none`; anything else as `some(value)`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let inner = &self.0;
        if !inner.is_none() {
            return serializer.serialize_some(inner);
        }
        serializer.serialize_none()
    }
}
174
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for MaybeNull<T>
where
    T: Nullable + Deserialize<'de>,
{
    /// Deserializes from an optional `T`.
    ///
    /// A missing value becomes the null state. A present value equal to the
    /// null marker is rejected, since it could not round-trip as `Some`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let decoded = Option::<T>::deserialize(deserializer)?;
        // Reuse the `TryFrom<Option<T>>` rules so both conversion paths agree
        // on what counts as an invalid `Some`.
        Self::try_from(decoded).map_err(|_| {
            <D::Error as serde::de::Error>::custom(
                "Invalid MaybeNull encoding: Some(value) cannot equal the None marker.",
            )
        })
    }
}
194
#[cfg(test)]
mod tests {
    use super::*;

    impl Nullable for u64 {
        const NONE: Self = 0;
    }

    #[test]
    fn test_layout_matches_inner_type() {
        // `#[repr(transparent)]` must keep the wrapper tag-free: same size and
        // alignment as the wrapped type.
        assert_eq!(
            core::mem::size_of::<MaybeNull<u64>>(),
            core::mem::size_of::<u64>()
        );
        assert_eq!(
            core::mem::align_of::<MaybeNull<u64>>(),
            core::mem::align_of::<u64>()
        );
    }

    #[test]
    fn test_try_from_option() {
        let some = Some(42u64);
        assert_eq!(MaybeNull::try_from(some).unwrap(), MaybeNull(42u64));

        let none: Option<u64> = None;
        assert_eq!(MaybeNull::try_from(none).unwrap(), MaybeNull::from(0u64));

        let invalid = Some(0u64);
        assert_eq!(
            MaybeNull::try_from(invalid).unwrap_err(),
            MaybeNullError::NoneValueInSome,
        );
    }

    #[test]
    fn test_try_from_roundtrips_through_option() {
        // Every valid `Option<u64>` must survive `TryFrom` followed by `From`.
        for value in [None, Some(1u64), Some(u64::MAX)] {
            let wrapped = MaybeNull::try_from(value).unwrap();
            assert_eq!(Option::<u64>::from(wrapped), value);
        }
    }

    #[test]
    fn test_from_maybe_null() {
        let some = MaybeNull::from(42u64);
        let none = MaybeNull::from(0u64);

        assert_eq!(Option::<u64>::from(some), Some(42));
        assert_eq!(Option::<u64>::from(none), None);
    }

    #[test]
    fn test_default() {
        let def = MaybeNull::<u64>::default();
        assert_eq!(def, MaybeNull(0u64));
        assert_eq!(def.get(), None);
    }

    #[test]
    fn test_copied() {
        let some = MaybeNull::from(42u64);
        assert_eq!(some.copied(), Some(42));

        let none = MaybeNull::from(0u64);
        assert_eq!(none.copied(), None);
    }

    #[test]
    fn test_nullable_predicates() {
        assert!(u64::NONE.is_none());
        assert!(!u64::NONE.is_some());
        assert!(8u64.is_some());
        assert!(!8u64.is_none());
    }

    #[test]
    fn test_as_ref() {
        let some = MaybeNull::from(8u64);
        assert_eq!(some.as_ref(), Some(&8u64));

        let none = MaybeNull::from(u64::NONE);
        assert_eq!(none.as_ref(), None);
    }

    #[test]
    fn test_as_mut() {
        let mut some = MaybeNull::from(3u64);
        assert!(some.as_mut().is_some());
        *some.as_mut().unwrap() = 4;
        assert_eq!(some.get(), Some(4));

        let mut none = MaybeNull::from(0u64);
        assert!(none.as_mut().is_none());
    }

    #[derive(Clone, Debug, PartialEq)]
    struct TestNonCopyNullable([u8; 4]);

    impl Nullable for TestNonCopyNullable {
        const NONE: Self = Self([0u8; 4]);
    }

    #[test]
    fn test_cloned_with_non_copy_nullable() {
        let some = MaybeNull::from(TestNonCopyNullable([1, 2, 3, 4]));
        assert_eq!(some.cloned(), Some(TestNonCopyNullable([1, 2, 3, 4])));

        let none = MaybeNull::from(TestNonCopyNullable::NONE);
        assert_eq!(none.cloned(), None);
    }

    #[cfg(feature = "borsh")]
    mod borsh_tests {
        use {super::*, alloc::vec};

        #[test]
        fn test_borsh_roundtrip_u64() {
            let some = MaybeNull::from(42u64);
            let none = MaybeNull::from(0u64);

            let some_bytes = borsh::to_vec(&some).unwrap();
            let none_bytes = borsh::to_vec(&none).unwrap();

            assert_eq!(some_bytes, 42u64.to_le_bytes().to_vec());
            assert_eq!(none_bytes, vec![0; 8]);
            assert_eq!(
                borsh::from_slice::<MaybeNull<u64>>(&some_bytes).unwrap(),
                some
            );
            assert_eq!(
                borsh::from_slice::<MaybeNull<u64>>(&none_bytes).unwrap(),
                none
            );
            assert!(borsh::from_slice::<MaybeNull<u64>>(&[]).is_err());
        }
    }

    #[cfg(feature = "wincode")]
    mod wincode_tests {
        use {super::*, wincode::ZeroCopy};

        #[test]
        fn test_wincode_maybe_null_roundtrip_and_size() {
            let some = MaybeNull::from(9u64);
            let none = MaybeNull::from(0u64);

            let some_bytes = wincode::serialize(&some).unwrap();
            let none_bytes = wincode::serialize(&none).unwrap();

            assert_eq!(some_bytes.len(), core::mem::size_of::<u64>());
            assert_eq!(none_bytes.len(), core::mem::size_of::<u64>());
            assert_eq!(some_bytes.as_slice(), &9u64.to_le_bytes());
            assert_eq!(none_bytes.as_slice(), &0u64.to_le_bytes());

            let some_roundtrip: MaybeNull<u64> = wincode::deserialize(&some_bytes).unwrap();
            let none_roundtrip: MaybeNull<u64> = wincode::deserialize(&none_bytes).unwrap();
            assert_eq!(some_roundtrip, some);
            assert_eq!(none_roundtrip, none);

            // NOTE(review): `some_bytes` / `none_bytes` are `Vec<u8>` buffers;
            // confirm how wincode's zero-copy `from_bytes` handles a buffer that
            // is not 8-byte aligned for `u64`.
            let some_zero_copy = MaybeNull::<u64>::from_bytes(&some_bytes).unwrap();
            let none_zero_copy = MaybeNull::<u64>::from_bytes(&none_bytes).unwrap();
            assert_eq!(some_zero_copy, &some);
            assert_eq!(none_zero_copy, &none);
        }

        #[test]
        fn test_wincode_maybe_null_rejects_truncated_input() {
            assert!(wincode::deserialize::<MaybeNull<u64>>(&[]).is_err());
            assert!(wincode::deserialize::<MaybeNull<u64>>(&[0; 7]).is_err());
        }
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use {super::*, alloc::string::ToString};

        #[test]
        fn test_serde_u64_some() {
            let some = MaybeNull::from(7u64);
            let serialized = serde_json::to_string(&some).unwrap();
            assert_eq!(serialized, "7");
            let deserialized = serde_json::from_str::<MaybeNull<u64>>(&serialized).unwrap();
            assert_eq!(deserialized, some);
        }

        #[test]
        fn test_serde_u64_none() {
            // The null state must serialize as JSON `null` and read back unchanged.
            let none = MaybeNull::from(0u64);
            assert_eq!(serde_json::to_string(&none).unwrap(), "null");

            let deserialized = serde_json::from_str::<MaybeNull<u64>>("null").unwrap();
            assert_eq!(deserialized, MaybeNull::from(0));
        }

        #[test]
        fn test_serde_u64_none_marker_error_message() {
            let err = serde_json::from_str::<MaybeNull<u64>>("0").unwrap_err();
            let message = err.to_string();
            assert!(message.contains("MaybeNull encoding"));
            assert!(message.contains("None marker"));
        }

        #[test]
        fn test_serde_u64_reject_invalid_input() {
            assert!(serde_json::from_str::<MaybeNull<u64>>("\"abc\"").is_err());
            assert!(serde_json::from_str::<MaybeNull<u64>>("{}").is_err());
        }
    }

    #[cfg(feature = "bytemuck")]
    mod bytemuck_tests {
        use super::*;

        #[test]
        fn test_maybe_null_u64() {
            let some = MaybeNull::from(42u64);
            assert_eq!(some.get(), Some(42));

            let none = MaybeNull::from(0u64);
            assert_eq!(none.get(), None);

            // Borrow the bytes of a real `u64` so the slice is guaranteed to be
            // aligned for `MaybeNull<u64>` and in native byte order. A `[u8; 8]`
            // from `to_le_bytes` has alignment 1 (so `from_bytes` may panic on
            // misalignment) and holds the wrong value on big-endian targets.
            let raw = 42u64;
            let value: &MaybeNull<u64> = bytemuck::from_bytes(bytemuck::bytes_of(&raw));
            assert_eq!(*value, MaybeNull::from(42u64));

            let raw_zero = 0u64;
            let value: &MaybeNull<u64> = bytemuck::from_bytes(bytemuck::bytes_of(&raw_zero));
            assert_eq!(*value, MaybeNull::from(0u64));
            assert_eq!(value.get(), None);
        }

        #[test]
        fn test_maybe_null_from_bytes_errors() {
            assert!(bytemuck::try_from_bytes::<MaybeNull<u64>>(&[]).is_err());
            assert!(bytemuck::try_from_bytes::<MaybeNull<u64>>(&[0; 1]).is_err());
        }
    }
}