vortex_dht/node.rs

1use std::{
2    net::SocketAddr,
3    ops::{Add, Deref, Sub},
4    time::Duration,
5};
6
7use bytes::Bytes;
8use rand::Rng;
9use serde_derive::{Deserialize, Serialize};
10use time::OffsetDateTime;
11
12// TODO: migrate to u128 + u32 BE endian large nums.
13// (Can use lexographical order)
14#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Deserialize, Serialize)]
15pub struct NodeId([u8; 20]);
16
17pub const ID_ZERO: NodeId = NodeId([0; 20]);
18pub const ID_MAX: NodeId = NodeId([0xFF; 20]);
19
20impl NodeId {
21    // TODO: don't change in place?
22    pub fn halve(&mut self) {
23        let mut carry = false;
24        self.0.iter_mut().for_each(|byte| {
25            let mut new_byte = *byte >> 1;
26            if carry {
27                new_byte |= 0b1000_0000;
28            }
29            carry = *byte & 0b0000_0001 != 0;
30            *byte = new_byte;
31        });
32    }
33
34    // a bit odd to return another node id here
35    pub fn distance(&self, other: &NodeId) -> NodeId {
36        // Almost optimal asm generated but can be improved
37        let mut dist = [0; 20];
38        self.0
39            .iter()
40            .zip(other.0.iter())
41            .zip(dist.iter_mut())
42            .for_each(|((a, b), res)| *res = a ^ b);
43        NodeId(dist)
44    }
45
46    // TODO: duplicated with deref impl
47    pub fn as_bytes(&self) -> [u8; 20] {
48        self.0
49    }
50
51    /// Generates a new node id in range [min, max)
52    pub fn new_in_range(min: &NodeId, max: &NodeId) -> NodeId {
53        let mut delta = max - min;
54        let mut rng = rand::thread_rng();
55        for delta_byte in delta.0.iter_mut() {
56            *delta_byte = (rng.gen::<f32>() * *delta_byte as f32) as u8;
57        }
58        &delta + min
59    }
60}
61
62impl Add for &NodeId {
63    type Output = NodeId;
64
65    // TODO optimize with arch intrinsics by first converting to u32
66    fn add(self, rhs: Self) -> Self::Output {
67        let mut carry = false;
68        let mut result = [0; 20];
69        self.0
70            .iter()
71            .rev()
72            .zip(rhs.0.iter().rev())
73            .zip(result.iter_mut().rev())
74            .for_each(|((own, other), res)| {
75                let (num, new_carry) = own.overflowing_add(*other);
76                *res = num;
77                if carry {
78                    let (num, extra_carry) = res.overflowing_add(1);
79                    *res = num;
80                    carry = new_carry | extra_carry;
81                } else {
82                    carry = new_carry;
83                }
84            });
85        NodeId(result)
86    }
87}
88
89impl Sub for &NodeId {
90    type Output = NodeId;
91
92    // TODO optimize with arch intrinsics by first converting to u32
93    fn sub(self, rhs: Self) -> Self::Output {
94        let mut carry = false;
95        let mut result = [0; 20];
96        self.0
97            .iter()
98            .rev()
99            .zip(rhs.0.iter().rev())
100            .zip(result.iter_mut().rev())
101            .for_each(|((own, other), res)| {
102                let (num, new_carry) = own.overflowing_sub(*other);
103                *res = num;
104                if carry {
105                    let (num, extra_carry) = res.overflowing_sub(1);
106                    *res = num;
107                    carry = new_carry | extra_carry;
108                } else {
109                    carry = new_carry;
110                }
111            });
112        NodeId(result)
113    }
114}
115
116#[inline]
117pub fn midpoint(low: &NodeId, high: &NodeId) -> NodeId {
118    assert!(low < high);
119    let mut diff = high - low;
120    diff.halve();
121    low + &diff
122}
123
124impl From<Bytes> for NodeId {
125    fn from(bytes: Bytes) -> Self {
126        bytes[..].into()
127    }
128}
129
130impl From<&[u8]> for NodeId {
131    fn from(slice: &[u8]) -> Self {
132        // use maybe uninit
133        let mut id = [0; 20];
134        id.copy_from_slice(slice);
135        NodeId(id)
136    }
137}
138
139impl Deref for NodeId {
140    type Target = [u8; 20];
141
142    fn deref(&self) -> &Self::Target {
143        &self.0
144    }
145}
146
147impl core::fmt::Debug for NodeId {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        f.debug_tuple("NodeId")
150            .field(&format!("{:02x?}", &self.0))
151            .finish()
152    }
153}
154
155#[derive(Debug, Copy, Clone, PartialEq, Deserialize, Serialize)]
156pub enum NodeStatus {
157    Good,
158    Bad,
159    Unknown,
160}
161
162#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
163pub struct Node {
164    pub id: NodeId,
165    pub addr: SocketAddr,
166    pub last_status: NodeStatus,
167    pub last_seen: OffsetDateTime,
168}
169
170impl Node {
171    // TODO: maybe &mut and update it here
172    pub fn current_status(&self) -> NodeStatus {
173        match self.last_status {
174            NodeStatus::Good => {
175                let stale =
176                    OffsetDateTime::now_utc() - self.last_seen > Duration::from_secs(15 * 60);
177                if stale {
178                    NodeStatus::Unknown
179                } else {
180                    NodeStatus::Good
181                }
182            }
183            NodeStatus::Unknown => NodeStatus::Unknown,
184            NodeStatus::Bad => NodeStatus::Bad,
185        }
186    }
187}
188
#[cfg(test)]
mod test {
    use num_bigint::BigInt;

    use super::*;

    /// Test-only conversion from a non-negative `BigInt` below 2^160.
    ///
    /// `to_bytes_be` returns the minimal big-endian encoding, so values that
    /// need fewer than 20 bytes are left-padded with zeros. The sign is ignored.
    impl From<BigInt> for NodeId {
        fn from(bigint: BigInt) -> Self {
            let (_, bytes) = bigint.to_bytes_be();
            assert!(bytes.len() <= 20, "BigInt does not fit in 160 bits");
            let mut id = [0; 20];
            id[20 - bytes.len()..].copy_from_slice(&bytes);
            NodeId(id)
        }
    }

    #[test]
    fn test_addition() {
        // Sanity check with big int
        let bigint_a = BigInt::new(
            num_bigint::Sign::Plus,
            // LE digits
            vec![u32::MAX, u32::MAX, u32::MAX - 1, u32::MAX, u32::MAX - 1],
        );

        let bigint_b = BigInt::new(num_bigint::Sign::Plus, vec![0, 0, 2, 1, 0]);

        let expected: BigInt = bigint_a + bigint_b;

        assert_eq!(
            BigInt::new(
                num_bigint::Sign::Plus,
                vec![u32::MAX, u32::MAX, 0, 1, u32::MAX],
            ),
            expected
        );

        let expected: NodeId = expected.into();

        // BE bytes
        let nodeid_a = NodeId::from(
            [
                0xFF,
                0xFF,
                0xFF,
                0xFF - 1,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF - 1,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
                0xFF,
            ]
            .as_slice(),
        );

        let nodeid_b =
            NodeId::from([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0].as_slice());

        let actual = &nodeid_b + &nodeid_a;

        assert_eq!(expected, actual);
    }

    #[test]
    fn find_midpoint() {
        let high = BigInt::new(num_bigint::Sign::Plus, vec![u32::MAX; 5]);
        let low = high.clone() / 2;
        let mid: BigInt = (high + low) / 2;
        let mid_id: NodeId = mid.into();

        let high = ID_MAX;
        let mut low = ID_MAX;
        low.halve();

        // Exercise the real `midpoint` rather than re-deriving it here.
        assert_eq!(mid_id, midpoint(&low, &high));
    }

    #[test]
    fn test_subtraction() {
        // Sanity check with big int
        let bigint_a = BigInt::new(
            num_bigint::Sign::Plus,
            // LE digits
            vec![u32::MAX, u32::MAX, 0, 0, 1],
        );

        let bigint_b = BigInt::new(num_bigint::Sign::Plus, vec![0, 0, 2, 1, 0]);

        let expected: BigInt = bigint_a - bigint_b;

        assert_eq!(
            BigInt::new(
                num_bigint::Sign::Plus,
                vec![u32::MAX, u32::MAX, u32::MAX - 1, u32::MAX - 1],
            ),
            expected
        );

        // Only 16 bytes wide; the padded `From<BigInt>` handles that.
        let expected: NodeId = expected.into();

        // BE bytes
        let nodeid_a = NodeId::from(
            [
                0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            ]
            .as_slice(),
        );

        let nodeid_b =
            NodeId::from([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0].as_slice());

        let actual = &nodeid_a - &nodeid_b;

        assert_eq!(expected, actual);
    }

    #[test]
    fn bigint_conversion_pads_small_values() {
        let id: NodeId = BigInt::from(0x0102u32).into();
        let mut raw = [0; 20];
        raw[18] = 0x01;
        raw[19] = 0x02;
        assert_eq!(id, NodeId(raw));
    }

    #[test]
    fn halve_carries_between_bytes() {
        let mut raw = [0; 20];
        raw[0] = 0x01;
        raw[19] = 0x02;
        let mut id = NodeId(raw);
        id.halve();

        let mut expected = [0; 20];
        expected[1] = 0x80;
        expected[19] = 0x01;
        assert_eq!(id, NodeId(expected));
    }

    #[test]
    fn distance_is_bytewise_xor() {
        let mut a = [0; 20];
        a[0] = 0b1010;
        let mut b = [0; 20];
        b[0] = 0b0110;
        let mut expected = [0; 20];
        expected[0] = 0b1100;

        assert_eq!(NodeId(a).distance(&NodeId(b)), NodeId(expected));
        assert_eq!(NodeId(a).distance(&NodeId(a)), ID_ZERO);
        assert_eq!(ID_ZERO.distance(&ID_MAX), ID_MAX);
    }

    #[test]
    fn new_in_range_stays_in_bounds() {
        let mut min_raw = [0; 20];
        min_raw[0] = 0x10;
        let mut max_raw = [0; 20];
        max_raw[0] = 0x20;
        let (min, max) = (NodeId(min_raw), NodeId(max_raw));

        for _ in 0..100 {
            let id = NodeId::new_in_range(&min, &max);
            assert!(min <= id && id < max, "{:?} outside [{:?}, {:?})", id, min, max);
        }
    }
}