1use std::{
2 net::SocketAddr,
3 ops::{Add, Deref, Sub},
4 time::Duration,
5};
6
7use bytes::Bytes;
8use rand::Rng;
9use serde_derive::{Deserialize, Serialize};
10use time::OffsetDateTime;
11
12#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Deserialize, Serialize)]
15pub struct NodeId([u8; 20]);
16
17pub const ID_ZERO: NodeId = NodeId([0; 20]);
18pub const ID_MAX: NodeId = NodeId([0xFF; 20]);
19
20impl NodeId {
21 pub fn halve(&mut self) {
23 let mut carry = false;
24 self.0.iter_mut().for_each(|byte| {
25 let mut new_byte = *byte >> 1;
26 if carry {
27 new_byte |= 0b1000_0000;
28 }
29 carry = *byte & 0b0000_0001 != 0;
30 *byte = new_byte;
31 });
32 }
33
34 pub fn distance(&self, other: &NodeId) -> NodeId {
36 let mut dist = [0; 20];
38 self.0
39 .iter()
40 .zip(other.0.iter())
41 .zip(dist.iter_mut())
42 .for_each(|((a, b), res)| *res = a ^ b);
43 NodeId(dist)
44 }
45
46 pub fn as_bytes(&self) -> [u8; 20] {
48 self.0
49 }
50
51 pub fn new_in_range(min: &NodeId, max: &NodeId) -> NodeId {
53 let mut delta = max - min;
54 let mut rng = rand::thread_rng();
55 for delta_byte in delta.0.iter_mut() {
56 *delta_byte = (rng.gen::<f32>() * *delta_byte as f32) as u8;
57 }
58 &delta + min
59 }
60}
61
62impl Add for &NodeId {
63 type Output = NodeId;
64
65 fn add(self, rhs: Self) -> Self::Output {
67 let mut carry = false;
68 let mut result = [0; 20];
69 self.0
70 .iter()
71 .rev()
72 .zip(rhs.0.iter().rev())
73 .zip(result.iter_mut().rev())
74 .for_each(|((own, other), res)| {
75 let (num, new_carry) = own.overflowing_add(*other);
76 *res = num;
77 if carry {
78 let (num, extra_carry) = res.overflowing_add(1);
79 *res = num;
80 carry = new_carry | extra_carry;
81 } else {
82 carry = new_carry;
83 }
84 });
85 NodeId(result)
86 }
87}
88
89impl Sub for &NodeId {
90 type Output = NodeId;
91
92 fn sub(self, rhs: Self) -> Self::Output {
94 let mut carry = false;
95 let mut result = [0; 20];
96 self.0
97 .iter()
98 .rev()
99 .zip(rhs.0.iter().rev())
100 .zip(result.iter_mut().rev())
101 .for_each(|((own, other), res)| {
102 let (num, new_carry) = own.overflowing_sub(*other);
103 *res = num;
104 if carry {
105 let (num, extra_carry) = res.overflowing_sub(1);
106 *res = num;
107 carry = new_carry | extra_carry;
108 } else {
109 carry = new_carry;
110 }
111 });
112 NodeId(result)
113 }
114}
115
116#[inline]
117pub fn midpoint(low: &NodeId, high: &NodeId) -> NodeId {
118 assert!(low < high);
119 let mut diff = high - low;
120 diff.halve();
121 low + &diff
122}
123
124impl From<Bytes> for NodeId {
125 fn from(bytes: Bytes) -> Self {
126 bytes[..].into()
127 }
128}
129
130impl From<&[u8]> for NodeId {
131 fn from(slice: &[u8]) -> Self {
132 let mut id = [0; 20];
134 id.copy_from_slice(slice);
135 NodeId(id)
136 }
137}
138
139impl Deref for NodeId {
140 type Target = [u8; 20];
141
142 fn deref(&self) -> &Self::Target {
143 &self.0
144 }
145}
146
147impl core::fmt::Debug for NodeId {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_tuple("NodeId")
150 .field(&format!("{:02x?}", &self.0))
151 .finish()
152 }
153}
154
155#[derive(Debug, Copy, Clone, PartialEq, Deserialize, Serialize)]
156pub enum NodeStatus {
157 Good,
158 Bad,
159 Unknown,
160}
161
162#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
163pub struct Node {
164 pub id: NodeId,
165 pub addr: SocketAddr,
166 pub last_status: NodeStatus,
167 pub last_seen: OffsetDateTime,
168}
169
170impl Node {
171 pub fn current_status(&self) -> NodeStatus {
173 match self.last_status {
174 NodeStatus::Good => {
175 let stale =
176 OffsetDateTime::now_utc() - self.last_seen > Duration::from_secs(15 * 60);
177 if stale {
178 NodeStatus::Unknown
179 } else {
180 NodeStatus::Good
181 }
182 }
183 NodeStatus::Unknown => NodeStatus::Unknown,
184 NodeStatus::Bad => NodeStatus::Bad,
185 }
186 }
187}
188
#[cfg(test)]
mod test {
    use num_bigint::BigInt;

    use super::*;

    // Test-only conversion so expected values can be computed with `BigInt`.
    // The sign is ignored.
    //
    // Panics if the magnitude needs more than 20 bytes.
    impl From<BigInt> for NodeId {
        fn from(bigint: BigInt) -> Self {
            let (_, bytes) = bigint.to_bytes_be();
            assert!(
                bytes.len() <= 20,
                "BigInt needs {} bytes and does not fit into a 20-byte NodeId",
                bytes.len()
            );
            // `to_bytes_be` yields the minimal big-endian encoding, so values
            // with leading zero bytes come back shorter than 20 bytes.
            // Right-align them instead of letting `copy_from_slice` panic.
            let mut padded = [0u8; 20];
            padded[20 - bytes.len()..].copy_from_slice(&bytes);
            NodeId(padded)
        }
    }

    /// Addition must agree with `BigInt`, including carries that ripple
    /// across 32-bit limb boundaries.
    #[test]
    fn test_addition() {
        // `BigInt::new` takes its limbs least significant first.
        let bigint_a = BigInt::new(
            num_bigint::Sign::Plus,
            vec![u32::MAX, u32::MAX, u32::MAX - 1, u32::MAX, u32::MAX - 1],
        );
        let bigint_b = BigInt::new(num_bigint::Sign::Plus, vec![0, 0, 2, 1, 0]);

        let expected: BigInt = bigint_a + bigint_b;
        assert_eq!(
            BigInt::new(
                num_bigint::Sign::Plus,
                vec![u32::MAX, u32::MAX, 0, 1, u32::MAX],
            ),
            expected
        );
        let expected: NodeId = expected.into();

        // Same operands as above, written as big-endian bytes.
        let nodeid_a = NodeId::from(
            [
                0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF,
                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            ]
            .as_slice(),
        );
        let nodeid_b =
            NodeId::from([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0].as_slice());

        let actual = &nodeid_b + &nodeid_a;

        assert_eq!(expected, actual);
    }

    /// `midpoint` must agree with `(high + low) / 2` computed with `BigInt`.
    #[test]
    fn find_midpoint() {
        let high = BigInt::new(
            num_bigint::Sign::Plus,
            vec![u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX],
        );
        let low = high.clone() / 2;
        let mid: BigInt = (high + low) / 2;
        let mid_id: NodeId = mid.into();

        let high = ID_MAX;
        let mut low = ID_MAX;
        low.halve();

        assert_eq!(mid_id, midpoint(&low, &high));
    }

    /// Subtraction must agree with `BigInt`, including borrows that ripple
    /// across 32-bit limb boundaries and a result with leading zero bytes.
    #[test]
    fn test_subtraction() {
        let bigint_a = BigInt::new(num_bigint::Sign::Plus, vec![u32::MAX, u32::MAX, 0, 0, 1]);
        let bigint_b = BigInt::new(num_bigint::Sign::Plus, vec![0, 0, 2, 1, 0]);

        let expected: BigInt = bigint_a - bigint_b;
        assert_eq!(
            BigInt::new(
                num_bigint::Sign::Plus,
                vec![u32::MAX, u32::MAX, u32::MAX - 1, u32::MAX - 1],
            ),
            expected
        );
        // The result only occupies 16 bytes; the conversion left-pads it.
        let expected: NodeId = expected.into();

        let nodeid_a = NodeId::from(
            [
                0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            ]
            .as_slice(),
        );
        let nodeid_b =
            NodeId::from([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0].as_slice());

        let actual = &nodeid_a - &nodeid_b;

        assert_eq!(expected, actual);
    }
}