wavs_types/
bytes.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::fmt::{self, Debug, Display, Formatter};
3use thiserror::Error;
4use utoipa::ToSchema;
5
6/// A newtype that wraps a `[u8; N]` using const generics.
7/// and is serialized as a `0x` prefixed hex string.
8#[derive(Clone, PartialEq, Eq, Hash, Copy, ToSchema)]
9pub struct ByteArray<const N: usize>([u8; N]);
10
11impl<const N: usize> ByteArray<N> {
12    pub fn new(data: [u8; N]) -> Self {
13        ByteArray(data)
14    }
15
16    pub fn as_slice(&self) -> &[u8] {
17        &self.0
18    }
19
20    pub fn as_mut_slice(&mut self) -> &mut [u8] {
21        &mut self.0
22    }
23
24    pub fn into_inner(self) -> [u8; N] {
25        self.0
26    }
27
28    pub fn len(&self) -> usize {
29        N
30    }
31
32    pub fn is_empty(&self) -> bool {
33        N == 0
34    }
35}
36
37impl<const N: usize> TryFrom<Vec<u8>> for ByteArray<N> {
38    type Error = ByteArrayError<N>;
39
40    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
41        if value.len() != N {
42            return Err(ByteArrayError { length: N });
43        }
44        let mut array = [0u8; N];
45        array.copy_from_slice(&value);
46        Ok(ByteArray(array))
47    }
48}
49
50impl<const N: usize> Display for ByteArray<N> {
51    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
52        // Encode the byte array as hex using `const_hex::encode`.
53        let hex_string = const_hex::encode(self.0);
54        write!(f, "0x{hex_string}")
55    }
56}
57
58impl<const N: usize> Debug for ByteArray<N> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60        Display::fmt(self, f)
61    }
62}
63
64impl<const N: usize> Serialize for ByteArray<N> {
65    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66    where
67        S: Serializer,
68    {
69        // Serialize the hex string
70        let hex_string = self.to_string();
71        serializer.serialize_str(&hex_string)
72    }
73}
74
75impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
76    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77    where
78        D: Deserializer<'de>,
79    {
80        // Deserialize as a string, then decode that string from hex.
81        let s = String::deserialize(deserializer)?;
82        let bytes = const_hex::decode(&s).map_err(|e| serde::de::Error::custom(e.to_string()))?;
83
84        // Ensure the decoded bytes have the correct length.
85        let array: [u8; N] = bytes
86            .try_into()
87            .map_err(|_| serde::de::Error::custom("invalid hex length"))?;
88
89        Ok(ByteArray(array))
90    }
91}
92
93#[cfg(test)]
94mod tests {
95    use super::*;
96
97    #[test]
98    fn test_display() {
99        let data = ByteArray::<4>([0xDE, 0xAD, 0xBE, 0xEF]);
100        assert_eq!(format!("{data}"), "0xdeadbeef");
101    }
102
103    #[test]
104    fn test_serde() {
105        let data = ByteArray::<4>([0xDE, 0xAD, 0xBE, 0xEF]);
106
107        // Test serialization
108        let serialized = serde_json::to_string(&data).unwrap();
109        // Expect a JSON string: "deadbeef"
110        assert_eq!(serialized, "\"0xdeadbeef\"");
111
112        // Test deserialization
113        let deserialized: ByteArray<4> = serde_json::from_str("\"0xdeadbeef\"").unwrap();
114        assert_eq!(deserialized.0, [0xDE, 0xAD, 0xBE, 0xEF]);
115    }
116
117    #[test]
118    fn test_debug() {
119        let data = ByteArray::<4>([0xDE, 0xAD, 0xBE, 0xEF]);
120        assert_eq!(format!("{data:?}"), "0xdeadbeef");
121        assert_eq!(format!("{data:#?}"), "0xdeadbeef");
122    }
123}
124
125#[derive(Error, Debug)]
126#[error("ByteArray<{N}> must be exactly {N} bytes long")]
127pub struct ByteArrayError<const N: usize> {
128    pub length: usize,
129}