tycho_network/dht/
mod.rs

1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{broadcast, Notify};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13    PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22    rpc, NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName,
23    PeerValueKeyRef, PeerValueRef, Value, ValueRef, ValueResponseRaw,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
// Counters

/// Incremented once for every incoming query or message.
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Incremented when an incoming query or message could not be parsed or handled.
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Incremented for incoming requests prefixed with `rpc::WithPeerInfo`.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Incremented for incoming `rpc::FindNode` queries.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Incremented for incoming `rpc::FindValue` queries.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Incremented for incoming `rpc::GetNodeInfo` queries.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Incremented for incoming store messages.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47    inner: Arc<DhtInner>,
48    network: Network,
49}
50
51impl DhtClient {
52    #[inline]
53    pub fn network(&self) -> &Network {
54        &self.network
55    }
56
57    #[inline]
58    pub fn service(&self) -> &DhtService {
59        DhtService::wrap(&self.inner)
60    }
61
62    pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63        anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64        let added = self.inner.add_peer_info(&self.network, peer);
65        Ok(added)
66    }
67
68    pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
69        let res = self
70            .network
71            .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
72            .await?;
73        let NodeInfoResponse { info } = res.parse_tl()?;
74        Ok(info)
75    }
76
77    pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
78        DhtQueryBuilder {
79            inner: &self.inner,
80            network: &self.network,
81            name,
82            idx: 0,
83        }
84    }
85
86    /// Find a value by its key hash.
87    ///
88    /// This is quite a low-level method, so it is recommended to use [`DhtClient::entry`].
89    pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
90        self.inner.find_value(&self.network, key_hash, mode).await
91    }
92}
93
94#[derive(Clone, Copy)]
95pub struct DhtQueryBuilder<'a> {
96    inner: &'a DhtInner,
97    network: &'a Network,
98    name: PeerValueKeyName,
99    idx: u32,
100}
101
102impl<'a> DhtQueryBuilder<'a> {
103    #[inline]
104    pub fn with_idx(&mut self, idx: u32) -> &mut Self {
105        self.idx = idx;
106        self
107    }
108
109    pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
110    where
111        for<'tl> T: tl_proto::TlRead<'tl>,
112    {
113        let key_hash = tl_proto::hash(PeerValueKeyRef {
114            name: self.name,
115            peer_id,
116        });
117
118        match self
119            .inner
120            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
121            .await
122        {
123            Some(value) => match value.as_ref() {
124                Value::Peer(value) => {
125                    tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
126                }
127                Value::Merged(_) => Err(FindValueError::InvalidData(
128                    tl_proto::TlError::UnknownConstructor,
129                )),
130            },
131            None => Err(FindValueError::NotFound),
132        }
133    }
134
135    pub async fn find_peer_value_raw(
136        &self,
137        peer_id: &PeerId,
138    ) -> Result<Box<PeerValue>, FindValueError> {
139        let key_hash = tl_proto::hash(PeerValueKeyRef {
140            name: self.name,
141            peer_id,
142        });
143
144        match self
145            .inner
146            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
147            .await
148        {
149            Some(value) => {
150                realloc_box_enum!(value, {
151                    Value::Peer(value) => Box::new(value) => Ok(value),
152                    Value::Merged(_) => Err(FindValueError::InvalidData(
153                        tl_proto::TlError::UnknownConstructor,
154                    )),
155                })
156            }
157            None => Err(FindValueError::NotFound),
158        }
159    }
160
161    pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
162    where
163        T: tl_proto::TlWrite,
164    {
165        DhtQueryWithDataBuilder {
166            inner: *self,
167            data: tl_proto::serialize(&data),
168            at: None,
169            ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
170            with_peer_info: false,
171        }
172    }
173}
174
175pub struct DhtQueryWithDataBuilder<'a> {
176    inner: DhtQueryBuilder<'a>,
177    data: Vec<u8>,
178    at: Option<u32>,
179    ttl: u32,
180    with_peer_info: bool,
181}
182
183impl DhtQueryWithDataBuilder<'_> {
184    pub fn with_time(&mut self, at: u32) -> &mut Self {
185        self.at = Some(at);
186        self
187    }
188
189    pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
190        self.ttl = ttl;
191        self
192    }
193
194    pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
195        self.with_peer_info = with_peer_info;
196        self
197    }
198
199    pub async fn store(&self) -> Result<()> {
200        let dht = self.inner.inner;
201        let network = self.inner.network;
202
203        let mut value = self.make_unsigned_value_ref();
204        let signature = network.sign_tl(&value);
205        value.signature = &signature;
206
207        dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
208            .await
209    }
210
211    pub fn store_locally(&self) -> Result<bool, StorageError> {
212        let dht = self.inner.inner;
213        let network = self.inner.network;
214
215        let mut value = self.make_unsigned_value_ref();
216        let signature = network.sign_tl(&value);
217        value.signature = &signature;
218
219        dht.store_value_locally(&ValueRef::Peer(value))
220    }
221
222    pub fn into_signed_value(self) -> PeerValue {
223        let dht = self.inner.inner;
224        let network = self.inner.network;
225
226        let mut value = PeerValue {
227            key: PeerValueKey {
228                name: self.name,
229                peer_id: dht.local_id,
230            },
231            data: self.data.into_boxed_slice(),
232            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
233            signature: Box::new([0; 64]),
234        };
235        *value.signature = network.sign_tl(&value);
236        value
237    }
238
239    fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
240        PeerValueRef {
241            key: PeerValueKeyRef {
242                name: self.inner.name,
243                peer_id: &self.inner.inner.local_id,
244            },
245            data: &self.data,
246            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
247            signature: &[0; 64],
248        }
249    }
250}
251
252impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
253    type Target = DhtQueryBuilder<'a>;
254
255    #[inline]
256    fn deref(&self) -> &Self::Target {
257        &self.inner
258    }
259}
260
261impl<'a> std::ops::DerefMut for DhtQueryWithDataBuilder<'a> {
262    #[inline]
263    fn deref_mut(&mut self) -> &mut Self::Target {
264        &mut self.inner
265    }
266}
267
268pub struct DhtServiceBackgroundTasks {
269    inner: Arc<DhtInner>,
270}
271
272impl DhtServiceBackgroundTasks {
273    pub fn spawn(self, network: &Network) {
274        self.inner
275            .start_background_tasks(Network::downgrade(network));
276    }
277}
278
279pub struct DhtServiceBuilder {
280    local_id: PeerId,
281    config: Option<DhtConfig>,
282}
283
284impl DhtServiceBuilder {
285    pub fn with_config(mut self, config: DhtConfig) -> Self {
286        self.config = Some(config);
287        self
288    }
289
290    pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
291        let config = self.config.unwrap_or_default();
292
293        let storage = {
294            let mut builder = Storage::builder()
295                .with_max_capacity(config.max_storage_capacity)
296                .with_max_ttl(config.max_stored_value_ttl);
297
298            if let Some(time_to_idle) = config.storage_item_time_to_idle {
299                builder = builder.with_max_idle(time_to_idle);
300            }
301
302            builder.build()
303        };
304
305        let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
306
307        let inner = Arc::new(DhtInner {
308            local_id: self.local_id,
309            routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
310            storage,
311            local_peer_info: Mutex::new(None),
312            config,
313            announced_peers,
314            find_value_queries: Default::default(),
315            peer_added: Arc::new(Default::default()),
316        });
317
318        let background_tasks = DhtServiceBackgroundTasks {
319            inner: inner.clone(),
320        };
321
322        (background_tasks, DhtService(inner))
323    }
324}
325
326#[derive(Clone)]
327#[repr(transparent)]
328pub struct DhtService(Arc<DhtInner>);
329
330impl DhtService {
331    #[inline]
332    fn wrap(inner: &Arc<DhtInner>) -> &Self {
333        // SAFETY: `DhtService` has the same memory layout as `Arc<DhtInner>`.
334        unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
335    }
336
337    pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
338        DhtServiceBuilder {
339            local_id,
340            config: None,
341        }
342    }
343
344    pub fn local_id(&self) -> &PeerId {
345        &self.0.local_id
346    }
347
348    pub fn make_client(&self, network: &Network) -> DhtClient {
349        DhtClient {
350            inner: self.0.clone(),
351            network: network.clone(),
352        }
353    }
354
355    pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
356        PeerResolver::builder(self.clone())
357    }
358
359    pub fn has_peer(&self, peer_id: &PeerId) -> bool {
360        self.0.routing_table.lock().unwrap().contains(peer_id)
361    }
362
363    pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
364        self.0.store_value_locally(value)
365    }
366
367    pub fn insert_merger(
368        &self,
369        group_id: &[u8; 32],
370        merger: Arc<dyn DhtValueMerger>,
371    ) -> Option<Arc<dyn DhtValueMerger>> {
372        self.0.storage.insert_merger(group_id, merger)
373    }
374
375    pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
376        self.0.storage.remove_merger(group_id)
377    }
378
379    pub fn peer_added(&self) -> &Arc<Notify> {
380        &self.0.peer_added
381    }
382}
383
384impl Service<ServiceRequest> for DhtService {
385    type QueryResponse = Response;
386    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
387    type OnMessageFuture = futures_util::future::Ready<()>;
388    type OnDatagramFuture = futures_util::future::Ready<()>;
389
390    #[tracing::instrument(
391        level = "debug",
392        name = "on_dht_query",
393        skip_all,
394        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
395    )]
396    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
397        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);
398
399        let (constructor, body) = match self.0.try_handle_prefix(&req) {
400            Ok(rest) => rest,
401            Err(e) => {
402                tracing::debug!("failed to deserialize query: {e}");
403                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
404                return futures_util::future::ready(None);
405            }
406        };
407
408        let response = crate::match_tl_request!(body, tag = constructor, {
409            rpc::FindNode as ref r => {
410                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
411                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);
412
413                let res = self.0.handle_find_node(r);
414                Some(tl_proto::serialize(res))
415            },
416            rpc::FindValue as ref r => {
417                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
418                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);
419
420                let res = self.0.handle_find_value(r);
421                Some(tl_proto::serialize(res))
422            },
423            rpc::GetNodeInfo as _ => {
424                tracing::debug!("get_node_info");
425                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);
426
427                self.0.handle_get_node_info().map(tl_proto::serialize)
428            },
429        }, e => {
430            tracing::debug!("failed to deserialize query: {e}");
431            None
432        });
433
434        if response.is_none() {
435            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
436        }
437
438        futures_util::future::ready(response.map(|body| Response {
439            version: Default::default(),
440            body: Bytes::from(body),
441        }))
442    }
443
444    #[tracing::instrument(
445        level = "debug",
446        name = "on_dht_message",
447        skip_all,
448        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
449    )]
450    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
451        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);
452
453        let (constructor, body) = match self.0.try_handle_prefix(&req) {
454            Ok(rest) => rest,
455            Err(e) => {
456                tracing::debug!("failed to deserialize message: {e}");
457                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
458                return futures_util::future::ready(());
459            }
460        };
461
462        let mut has_error = false;
463        crate::match_tl_request!(body, tag = constructor, {
464            rpc::StoreRef<'_> as ref r => {
465                tracing::debug!("store");
466                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);
467
468                if let Err(e) = self.0.handle_store(r) {
469                    tracing::debug!("failed to store value: {e}");
470                    has_error = true;
471                }
472            }
473        }, e => {
474            tracing::debug!("failed to deserialize message: {e}");
475            has_error = true;
476        });
477
478        if has_error {
479            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
480        }
481
482        futures_util::future::ready(())
483    }
484
485    #[inline]
486    fn on_datagram(&self, _req: ServiceRequest) -> Self::OnDatagramFuture {
487        futures_util::future::ready(())
488    }
489}
490
491impl Routable for DhtService {
492    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
493        [
494            rpc::WithPeerInfo::TL_ID,
495            rpc::FindNode::TL_ID,
496            rpc::FindValue::TL_ID,
497            rpc::GetNodeInfo::TL_ID,
498        ]
499    }
500
501    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
502        [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
503    }
504}
505
506struct DhtInner {
507    local_id: PeerId,
508    routing_table: Mutex<HandlesRoutingTable>,
509    storage: Storage,
510    local_peer_info: Mutex<Option<PeerInfo>>,
511    config: DhtConfig,
512    announced_peers: broadcast::Sender<Arc<PeerInfo>>,
513    find_value_queries: QueryCache<Option<Box<Value>>>,
514    peer_added: Arc<Notify>,
515}
516
517impl DhtInner {
518    async fn find_value(
519        &self,
520        network: &Network,
521        key_hash: &[u8; 32],
522        mode: DhtQueryMode,
523    ) -> Option<Box<Value>> {
524        self.find_value_queries
525            .run(key_hash, || {
526                let query = Query::new(
527                    network.clone(),
528                    &self.routing_table.lock().unwrap(),
529                    key_hash,
530                    self.config.max_k,
531                    mode,
532                );
533
534                // NOTE: expression is intentionally split to drop the routing table guard
535                Box::pin(query.find_value())
536            })
537            .await
538    }
539
540    async fn store_value(
541        &self,
542        network: &Network,
543        value: &ValueRef<'_>,
544        with_peer_info: bool,
545    ) -> Result<()> {
546        self.storage.insert(DhtValueSource::Local, value)?;
547
548        let local_peer_info = if with_peer_info {
549            let mut node_info = self.local_peer_info.lock().unwrap();
550            Some(
551                node_info
552                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
553                    .clone(),
554            )
555        } else {
556            None
557        };
558
559        let query = StoreValue::new(
560            network.clone(),
561            &self.routing_table.lock().unwrap(),
562            value,
563            self.config.max_k,
564            local_peer_info.as_ref(),
565        );
566
567        // NOTE: expression is intentionally split to drop the routing table guard
568        query.run().await;
569        Ok(())
570    }
571
572    fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
573        self.storage.insert(DhtValueSource::Local, value)
574    }
575
576    // NOTE: Requires the incoming peer info to be valid.
577    fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
578        if peer_info.id == self.local_id {
579            return false;
580        }
581
582        let mut routing_table = self.routing_table.lock().unwrap();
583        let added = routing_table.add(
584            peer_info.clone(),
585            self.config.max_k,
586            &self.config.max_peer_info_ttl,
587            |peer_info| network.known_peers().insert(peer_info, false).ok(),
588        );
589
590        if added {
591            self.peer_added.notify_waiters();
592        }
593
594        added
595    }
596
597    fn make_unsigned_peer_value<'a>(
598        &'a self,
599        name: PeerValueKeyName,
600        data: &'a [u8],
601        expires_at: u32,
602    ) -> PeerValueRef<'a> {
603        PeerValueRef {
604            key: PeerValueKeyRef {
605                name,
606                peer_id: &self.local_id,
607            },
608            data,
609            expires_at,
610            signature: &[0; 64],
611        }
612    }
613
614    fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
615        network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
616    }
617
618    fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
619        let mut body = req.as_ref();
620        anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
621
622        // NOTE: read constructor without advancing the body
623        let mut constructor = std::convert::identity(body).get_u32_le();
624        let mut offset = 0;
625
626        if constructor == rpc::WithPeerInfo::TL_ID {
627            metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);
628
629            let peer_info = rpc::WithPeerInfo::read_from(body, &mut offset)?.peer_info;
630            anyhow::ensure!(
631                peer_info.id == req.metadata.peer_id,
632                "suggested peer ID does not belong to the sender"
633            );
634            self.announced_peers.send(Arc::new(peer_info)).ok();
635
636            body = &body[offset..];
637            anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
638
639            // NOTE: read constructor without advancing the body
640            constructor = std::convert::identity(body).get_u32_le();
641        }
642
643        Ok((constructor, body))
644    }
645
646    fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
647        self.storage.insert(DhtValueSource::Remote, &req.value)
648    }
649
650    fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
651        let nodes = self
652            .routing_table
653            .lock()
654            .unwrap()
655            .closest(&req.key, (req.k as usize).min(self.config.max_k));
656
657        NodeResponse { nodes }
658    }
659
660    fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
661        if let Some(value) = self.storage.get(&req.key) {
662            ValueResponseRaw::Found(value)
663        } else {
664            let nodes = self
665                .routing_table
666                .lock()
667                .unwrap()
668                .closest(&req.key, (req.k as usize).min(self.config.max_k));
669
670            ValueResponseRaw::NotFound(nodes)
671        }
672    }
673
674    fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
675        self.local_peer_info
676            .lock()
677            .unwrap()
678            .clone()
679            .map(|info| NodeInfoResponse { info })
680    }
681}
682
683fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
684    let distance = MAX_XOR_DISTANCE - distance;
685
686    let mut result = *from;
687
688    let byte_offset = distance / 8;
689    rng.fill_bytes(&mut result.0[byte_offset..]);
690
691    let bit_offset = distance % 8;
692    if bit_offset != 0 {
693        let mask = 0xff >> bit_offset;
694        result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
695    }
696
697    result
698}
699
700pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
701    for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
702        let left = u64::from_be_bytes(left.try_into().unwrap());
703        let right = u64::from_be_bytes(right.try_into().unwrap());
704        let diff = left ^ right;
705        if diff != 0 {
706            return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
707        }
708    }
709
710    0
711}
712
/// Largest possible XOR distance between two peer IDs, in bits.
const MAX_XOR_DISTANCE: usize = 256;
714
715#[derive(Debug, thiserror::Error)]
716pub enum FindValueError {
717    #[error("failed to deserialize value: {0}")]
718    InvalidData(#[from] tl_proto::TlError),
719    #[error("value not found")]
720    NotFound,
721}
722
#[cfg(test)]
mod tests {
    use super::*;

    /// A key generated at distance `d` keeps the leading `256 - d` bits of the
    /// source key, so its XOR distance must never exceed `d`.
    #[test]
    fn proper_random_keys() {
        let mut rng = rand::thread_rng();

        let peer_id = rand::random();
        let random_id = random_key_at_distance(&peer_id, 20, &mut rng);
        println!("{peer_id}");
        println!("{random_id}");

        let distance = xor_distance(&peer_id, &random_id);
        println!("{distance}");
        assert!(distance <= 20);

        // Boundaries: no free bits, sub-byte, byte-aligned and full-width distances.
        for max_distance in [0, 1, 7, 8, 9, 255, 256] {
            let random_id = random_key_at_distance(&peer_id, max_distance, &mut rng);
            assert!(xor_distance(&peer_id, &random_id) <= max_distance);
        }

        // A key is at distance zero from itself.
        assert_eq!(xor_distance(&peer_id, &peer_id), 0);
    }
}