1use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
2use std::borrow::Cow;
3use std::convert::TryFrom;
4use std::str::FromStr;
5use std::{fmt, str};
6
/// Number of raw bytes stored in a `NodeId`.
const NODE_ID_LENGTH: usize = 20;
8
9#[derive(Clone, Debug, thiserror::Error, PartialEq, Serialize, Deserialize)]
10#[error("NodeId `{original_str}` parsing error: {msg}")]
11pub struct ParseError {
12 original_str: String,
13 msg: String,
14}
15
16impl ParseError {
17 fn new(original_str: impl Into<String>, msg: impl Into<String>) -> Self {
18 Self {
19 original_str: original_str.into(),
20 msg: msg.into(),
21 }
22 }
23}
24
25#[derive(Clone, Debug, thiserror::Error, PartialEq, Serialize, Deserialize)]
26#[error("NodeId parsing error: {msg}")]
27pub struct InvalidLengthError {
28 msg: String,
29}
30
31#[derive(Clone, Copy, Hash, PartialEq, Eq)]
33pub struct NodeId {
34 inner: [u8; NODE_ID_LENGTH],
35}
36
37impl NodeId {
38 #[inline(always)]
39 fn with_hex<F, R>(&self, f: F) -> R
40 where
41 F: FnOnce(&str) -> R,
42 {
43 let mut hex_str = [0u8; 42];
44
45 hex_str[0] = b'0';
46 hex_str[1] = b'x';
47
48 let mut ptr = 2;
49 for it in &self.inner {
50 let hi = (it >> 4) & 0xfu8;
51 let lo = it & 0xf;
52 hex_str[ptr] = HEX_CHARS[hi as usize];
53 hex_str[ptr + 1] = HEX_CHARS[lo as usize];
54 ptr += 2;
55 }
56 assert_eq!(ptr, hex_str.len());
57
58 let hex_str = unsafe { str::from_utf8_unchecked(&hex_str) };
59 f(hex_str)
60 }
61
62 #[inline]
63 pub fn into_array(self) -> [u8; NODE_ID_LENGTH] {
64 self.inner
65 }
66}
67
68impl Default for NodeId {
69 fn default() -> Self {
70 NodeId {
71 inner: [0; NODE_ID_LENGTH],
72 }
73 }
74}
75
76impl AsRef<[u8]> for NodeId {
77 fn as_ref(&self) -> &[u8] {
78 &self.inner
79 }
80}
81
82impl AsRef<[u8; NODE_ID_LENGTH]> for NodeId {
83 fn as_ref(&self) -> &[u8; NODE_ID_LENGTH] {
84 &self.inner
85 }
86}
87
88impl From<[u8; NODE_ID_LENGTH]> for NodeId {
89 fn from(inner: [u8; NODE_ID_LENGTH]) -> Self {
90 NodeId { inner }
91 }
92}
93
94impl TryFrom<&Vec<u8>> for NodeId {
95 type Error = InvalidLengthError;
96 fn try_from(inner: &Vec<u8>) -> Result<Self, InvalidLengthError> {
97 if inner.len() != NODE_ID_LENGTH {
98 return Err(InvalidLengthError {
99 msg: format!(
100 "Invalid length: {}, NodeId requires {}.",
101 inner.len(),
102 NODE_ID_LENGTH
103 ),
104 });
105 }
106 Ok(Self::from(inner.as_ref()))
107 }
108}
109
110impl<'a> From<&'a [u8]> for NodeId {
111 fn from(it: &'a [u8]) -> Self {
112 let mut inner = [0; NODE_ID_LENGTH];
113 inner.copy_from_slice(it);
114
115 NodeId { inner }
116 }
117}
118
119impl<'a> From<Cow<'a, [u8]>> for NodeId {
120 fn from(it: Cow<'a, [u8]>) -> Self {
121 it.as_ref().into()
122 }
123}
124
125#[inline]
126fn hex_to_dec(hex: u8, s: &str) -> Result<u8, ParseError> {
127 match hex {
128 b'A'..=b'F' => Ok(hex - b'A' + 10),
129 b'a'..=b'f' => Ok(hex - b'a' + 10),
130 b'0'..=b'9' => Ok(hex - b'0'),
131 _ => Err(ParseError::new(
132 s,
133 format!("expected hex chars, but got: `{}`", char::from(hex)),
134 )),
135 }
136}
137
138impl str::FromStr for NodeId {
139 type Err = ParseError;
140
141 fn from_str(s: &str) -> Result<Self, ParseError> {
142 let bytes = s.as_bytes();
143
144 if bytes.len() != 2 + NODE_ID_LENGTH * 2 {
146 return Err(ParseError::new(s, "expected length is 42 chars"));
147 }
148
149 if bytes[0] != b'0' || bytes[1] != b'x' {
150 return Err(ParseError::new(s, "expected 0x prefix"));
151 }
152
153 let mut inner = [0u8; NODE_ID_LENGTH];
154 let mut p = 0;
155
156 for b in bytes[2..].chunks(2) {
157 let (hi, lo) = (hex_to_dec(b[0], s)?, hex_to_dec(b[1], s)?);
158 inner[p] = (hi << 4) | lo;
159 p += 1;
160 }
161 assert_eq!(p, NODE_ID_LENGTH);
162
163 Ok(NodeId { inner })
164 }
165}
166
/// Lowercase ASCII hex digits, indexed by nibble value (0..=15).
static HEX_CHARS: [u8; 16] = *b"0123456789abcdef";
170
171impl Serialize for NodeId {
172 #[inline]
173 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
174 where
175 S: Serializer,
176 {
177 self.with_hex(|hex_str| serializer.serialize_str(hex_str))
178 }
179}
180
181impl fmt::Debug for NodeId {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 write!(f, "{}", self)
184 }
185}
186
187impl fmt::Display for NodeId {
188 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
189 self.with_hex(|hex_str| write!(f, "{}", hex_str))
190 }
191}
192
193struct NodeIdVisit;
194
195impl<'de> de::Visitor<'de> for NodeIdVisit {
196 type Value = NodeId;
197
198 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
199 write!(formatter, "a nodeId")
200 }
201
202 fn visit_str<E>(self, v: &str) -> Result<<Self as de::Visitor<'de>>::Value, E>
203 where
204 E: de::Error,
205 {
206 match NodeId::from_str(v) {
207 Ok(node_id) => Ok(node_id),
208 Err(_) => Err(de::Error::custom("invalid format")),
209 }
210 }
211
212 fn visit_bytes<E>(self, v: &[u8]) -> Result<<Self as de::Visitor<'de>>::Value, E>
213 where
214 E: de::Error,
215 {
216 if v.len() == NODE_ID_LENGTH {
217 let mut inner = [0u8; NODE_ID_LENGTH];
218 inner.copy_from_slice(v);
219 Ok(NodeId { inner })
220 } else {
221 Err(de::Error::custom("invalid format"))
222 }
223 }
224}
225
226impl<'de> Deserialize<'de> for NodeId {
227 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
228 where
229 D: Deserializer<'de>,
230 {
231 deserializer.deserialize_str(NodeIdVisit)
232 }
233}
234
/// Diesel integration: maps `NodeId` to and from a `Text` column using its
/// `0x`-prefixed hex string form.
#[cfg(feature = "with-diesel")]
// NOTE(review): `NodeIdProxy` below is never constructed; presumably that is
// what this allow is for — confirm before removing.
#[allow(dead_code)]
mod sql {
    use super::NodeId;
    use diesel::backend::Backend;
    use diesel::deserialize::FromSql;
    use diesel::expression::bound::Bound;
    use diesel::expression::AsExpression;
    use diesel::serialize::{self, Output, ToSql};
    use diesel::sql_types::Text;
    use diesel::*;

    /// Binds an owned `NodeId` as a `Text` expression holding its hex string.
    impl AsExpression<Text> for NodeId {
        type Expression = Bound<Text, String>;

        fn as_expression(self) -> Self::Expression {
            Bound::new(self.to_string())
        }
    }

    /// Binds a borrowed `NodeId` as a `Text` expression holding its hex string.
    impl AsExpression<Text> for &NodeId {
        type Expression = Bound<Text, String>;

        fn as_expression(self) -> Self::Expression {
            Bound::new(self.to_string())
        }
    }

    /// Reads a `Text` value and parses it as a `NodeId`; a parse failure is
    /// returned as the deserialization error.
    impl<DB> FromSql<Text, DB> for NodeId
    where
        DB: Backend,
        String: FromSql<Text, DB>,
    {
        fn from_sql(bytes: Option<&<DB as Backend>::RawValue>) -> deserialize::Result<Self> {
            let s: String = FromSql::from_sql(bytes)?;
            Ok(s.parse()?)
        }
    }

    /// Writes the identifier's hex string as a `Text` value.
    impl<DB> ToSql<Text, DB> for NodeId
    where
        DB: Backend,
        for<'b> &'b str: ToSql<Text, DB>,
    {
        // Uses the `serialize::Result` alias that belongs to `ToSql`; it is
        // the same underlying type as the previous `deserialize::Result<IsNull>`.
        fn to_sql<W: std::io::Write>(&self, out: &mut Output<'_, W, DB>) -> serialize::Result {
            self.with_hex(move |s: &str| ToSql::<Text, DB>::to_sql(s, out))
        }
    }

    /// Proxy type carrying the `FromSqlRow` derive for `NodeId` via
    /// `#[diesel(foreign_derive)]`.
    #[derive(FromSqlRow)]
    #[diesel(foreign_derive)]
    struct NodeIdProxy(NodeId);
}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// Parses `input` as a `NodeId`, expects failure, and returns the
    /// rendered error message.
    fn parse_failure(input: &str) -> String {
        input.parse::<NodeId>().unwrap_err().to_string()
    }

    #[test]
    fn parse_empty_str() {
        assert_eq!(
            parse_failure(""),
            "NodeId `` parsing error: expected length is 42 chars".to_string()
        );
    }

    #[test]
    fn parse_short_str() {
        assert_eq!(
            parse_failure("short"),
            "NodeId `short` parsing error: expected length is 42 chars".to_string()
        );
    }

    #[test]
    fn parse_long_str() {
        let input = "0123456789012345678901234567890123456789123";
        let expected = format!(
            "NodeId `{}` parsing error: expected length is 42 chars",
            input
        );
        assert_eq!(parse_failure(input), expected);
    }

    #[test]
    fn parse_wo_0x_str() {
        let input = "012345678901234567890123456789012345678912";
        let expected = format!("NodeId `{}` parsing error: expected 0x prefix", input);
        assert_eq!(parse_failure(input), expected);
    }

    #[test]
    fn parse_non_hex_str() {
        let input = "0xx000000000000000000000000000000000000000";
        let expected = format!(
            "NodeId `{}` parsing error: expected hex chars, but got: `x`",
            input
        );
        assert_eq!(parse_failure(input), expected);
    }

    #[test]
    fn parse_proper_str() {
        let input = "0xbabe000000000000000000000000000000000000";
        let parsed: NodeId = input.parse().unwrap();
        assert_eq!(parsed.to_string(), input.to_string());
    }

    #[test]
    fn try_from_too_long_vec() {
        // 22 bytes (values 0 through 21): two more than a NodeId holds.
        let test_vec: Vec<u8> = (0u8..22).collect();
        let result: Result<NodeId, InvalidLengthError> = (&test_vec).try_into();
        let expected = format!(
            "NodeId parsing error: Invalid length: 22, NodeId requires {}.",
            NODE_ID_LENGTH
        );
        assert_eq!(result.unwrap_err().to_string(), expected);
    }
}