1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::fmt::{self, Debug, Display, Formatter};
3use thiserror::Error;
4use utoipa::ToSchema;
5
6#[derive(Clone, PartialEq, Eq, Hash, Copy, ToSchema)]
9pub struct ByteArray<const N: usize>([u8; N]);
10
11impl<const N: usize> ByteArray<N> {
12 pub fn new(data: [u8; N]) -> Self {
13 ByteArray(data)
14 }
15
16 pub fn as_slice(&self) -> &[u8] {
17 &self.0
18 }
19
20 pub fn as_mut_slice(&mut self) -> &mut [u8] {
21 &mut self.0
22 }
23
24 pub fn into_inner(self) -> [u8; N] {
25 self.0
26 }
27
28 pub fn len(&self) -> usize {
29 N
30 }
31
32 pub fn is_empty(&self) -> bool {
33 N == 0
34 }
35}
36
37impl<const N: usize> TryFrom<Vec<u8>> for ByteArray<N> {
38 type Error = ByteArrayError<N>;
39
40 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
41 if value.len() != N {
42 return Err(ByteArrayError { length: N });
43 }
44 let mut array = [0u8; N];
45 array.copy_from_slice(&value);
46 Ok(ByteArray(array))
47 }
48}
49
50impl<const N: usize> Display for ByteArray<N> {
51 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
52 let hex_string = const_hex::encode(self.0);
54 write!(f, "0x{hex_string}")
55 }
56}
57
58impl<const N: usize> Debug for ByteArray<N> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60 Display::fmt(self, f)
61 }
62}
63
64impl<const N: usize> Serialize for ByteArray<N> {
65 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66 where
67 S: Serializer,
68 {
69 let hex_string = self.to_string();
71 serializer.serialize_str(&hex_string)
72 }
73}
74
75impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
76 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77 where
78 D: Deserializer<'de>,
79 {
80 let s = String::deserialize(deserializer)?;
82 let bytes = const_hex::decode(&s).map_err(|e| serde::de::Error::custom(e.to_string()))?;
83
84 let array: [u8; N] = bytes
86 .try_into()
87 .map_err(|_| serde::de::Error::custom("invalid hex length"))?;
88
89 Ok(ByteArray(array))
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture shared by every test: four bytes and their expected hex text.
    const SAMPLE: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];
    const SAMPLE_HEX: &str = "0xdeadbeef";
    const SAMPLE_JSON: &str = "\"0xdeadbeef\"";

    /// `Display` prints a `0x` prefix followed by lowercase hex digits.
    #[test]
    fn test_display() {
        let bytes = ByteArray::<4>(SAMPLE);
        assert_eq!(bytes.to_string(), SAMPLE_HEX);
    }

    /// JSON serialization yields the quoted hex string, and that same string
    /// parses back to the original bytes.
    #[test]
    fn test_serde() {
        let bytes = ByteArray::<4>(SAMPLE);

        let json = serde_json::to_string(&bytes).unwrap();
        assert_eq!(json, SAMPLE_JSON);

        let parsed: ByteArray<4> = serde_json::from_str(SAMPLE_JSON).unwrap();
        assert_eq!(parsed.0, SAMPLE);
    }

    /// `Debug` (plain and alternate) prints the same text as `Display`.
    #[test]
    fn test_debug() {
        let bytes = ByteArray::<4>(SAMPLE);
        let plain = format!("{bytes:?}");
        let pretty = format!("{bytes:#?}");
        assert_eq!(plain, SAMPLE_HEX);
        assert_eq!(pretty, SAMPLE_HEX);
    }
}
124
125#[derive(Error, Debug)]
126#[error("ByteArray<{N}> must be exactly {N} bytes long")]
127pub struct ByteArrayError<const N: usize> {
128 pub length: usize,
129}