// tycho_network/dht/mod.rs

1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{Notify, broadcast};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13    PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22    NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef,
23    PeerValueRef, Value, ValueRef, ValueResponseRaw, rpc,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
// Counters

/// Incremented once for every incoming DHT query or message.
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Incremented when an incoming request cannot be parsed or handled.
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Incremented when an incoming request is wrapped in `rpc::WithPeerInfo`.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Incremented for every incoming `rpc::FindNode` query.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Incremented for every incoming `rpc::FindValue` query.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Incremented for every incoming `rpc::GetNodeInfo` query.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Incremented for every incoming store message.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47    inner: Arc<DhtInner>,
48    network: Network,
49}
50
51impl DhtClient {
52    #[inline]
53    pub fn network(&self) -> &Network {
54        &self.network
55    }
56
57    #[inline]
58    pub fn service(&self) -> &DhtService {
59        DhtService::wrap(&self.inner)
60    }
61
62    pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63        anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64        let added = self.inner.add_peer_info(&self.network, peer);
65        Ok(added)
66    }
67
68    pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
69        let res = self
70            .network
71            .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
72            .await?;
73        let NodeInfoResponse { info } = res.parse_tl()?;
74        Ok(info)
75    }
76
77    pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
78        DhtQueryBuilder {
79            inner: &self.inner,
80            network: &self.network,
81            name,
82            idx: 0,
83        }
84    }
85
86    /// Find a value by its key hash.
87    ///
88    /// This is quite a low-level method, so it is recommended to use [`DhtClient::entry`].
89    pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
90        self.inner.find_value(&self.network, key_hash, mode).await
91    }
92}
93
94#[derive(Clone, Copy)]
95pub struct DhtQueryBuilder<'a> {
96    inner: &'a DhtInner,
97    network: &'a Network,
98    name: PeerValueKeyName,
99    idx: u32,
100}
101
102impl<'a> DhtQueryBuilder<'a> {
103    #[inline]
104    pub fn with_idx(&mut self, idx: u32) -> &mut Self {
105        self.idx = idx;
106        self
107    }
108
109    pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
110    where
111        for<'tl> T: tl_proto::TlRead<'tl>,
112    {
113        let key_hash = tl_proto::hash(PeerValueKeyRef {
114            name: self.name,
115            peer_id,
116        });
117
118        match self
119            .inner
120            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
121            .await
122        {
123            Some(value) => match value.as_ref() {
124                Value::Peer(value) => {
125                    tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
126                }
127                Value::Merged(_) => Err(FindValueError::InvalidData(
128                    tl_proto::TlError::UnknownConstructor,
129                )),
130            },
131            None => Err(FindValueError::NotFound),
132        }
133    }
134
135    pub async fn find_peer_value_raw(
136        &self,
137        peer_id: &PeerId,
138    ) -> Result<Box<PeerValue>, FindValueError> {
139        let key_hash = tl_proto::hash(PeerValueKeyRef {
140            name: self.name,
141            peer_id,
142        });
143
144        match self
145            .inner
146            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
147            .await
148        {
149            Some(value) => {
150                realloc_box_enum!(value, {
151                    Value::Peer(value) => Box::new(value) => Ok(value),
152                    Value::Merged(_) => Err(FindValueError::InvalidData(
153                        tl_proto::TlError::UnknownConstructor,
154                    )),
155                })
156            }
157            None => Err(FindValueError::NotFound),
158        }
159    }
160
161    pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
162    where
163        T: tl_proto::TlWrite,
164    {
165        DhtQueryWithDataBuilder {
166            inner: *self,
167            data: tl_proto::serialize(&data),
168            at: None,
169            ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
170            with_peer_info: false,
171        }
172    }
173}
174
175pub struct DhtQueryWithDataBuilder<'a> {
176    inner: DhtQueryBuilder<'a>,
177    data: Vec<u8>,
178    at: Option<u32>,
179    ttl: u32,
180    with_peer_info: bool,
181}
182
183impl DhtQueryWithDataBuilder<'_> {
184    pub fn with_time(&mut self, at: u32) -> &mut Self {
185        self.at = Some(at);
186        self
187    }
188
189    pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
190        self.ttl = ttl;
191        self
192    }
193
194    pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
195        self.with_peer_info = with_peer_info;
196        self
197    }
198
199    pub async fn store(&self) -> Result<()> {
200        let dht = self.inner.inner;
201        let network = self.inner.network;
202
203        let mut value = self.make_unsigned_value_ref();
204        let signature = network.sign_tl(&value);
205        value.signature = &signature;
206
207        dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
208            .await
209    }
210
211    pub fn store_locally(&self) -> Result<bool, StorageError> {
212        let dht = self.inner.inner;
213        let network = self.inner.network;
214
215        let mut value = self.make_unsigned_value_ref();
216        let signature = network.sign_tl(&value);
217        value.signature = &signature;
218
219        dht.store_value_locally(&ValueRef::Peer(value))
220    }
221
222    pub fn into_signed_value(self) -> PeerValue {
223        let dht = self.inner.inner;
224        let network = self.inner.network;
225
226        let mut value = PeerValue {
227            key: PeerValueKey {
228                name: self.name,
229                peer_id: dht.local_id,
230            },
231            data: self.data.into_boxed_slice(),
232            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
233            signature: Box::new([0; 64]),
234        };
235        *value.signature = network.sign_tl(&value);
236        value
237    }
238
239    fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
240        PeerValueRef {
241            key: PeerValueKeyRef {
242                name: self.inner.name,
243                peer_id: &self.inner.inner.local_id,
244            },
245            data: &self.data,
246            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
247            signature: &[0; 64],
248        }
249    }
250}
251
252impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
253    type Target = DhtQueryBuilder<'a>;
254
255    #[inline]
256    fn deref(&self) -> &Self::Target {
257        &self.inner
258    }
259}
260
261impl std::ops::DerefMut for DhtQueryWithDataBuilder<'_> {
262    #[inline]
263    fn deref_mut(&mut self) -> &mut Self::Target {
264        &mut self.inner
265    }
266}
267
268pub struct DhtServiceBackgroundTasks {
269    inner: Arc<DhtInner>,
270}
271
272impl DhtServiceBackgroundTasks {
273    pub fn spawn(self, network: &Network) {
274        self.inner
275            .start_background_tasks(Network::downgrade(network));
276    }
277}
278
279pub struct DhtServiceBuilder {
280    local_id: PeerId,
281    config: Option<DhtConfig>,
282}
283
284impl DhtServiceBuilder {
285    pub fn with_config(mut self, config: DhtConfig) -> Self {
286        self.config = Some(config);
287        self
288    }
289
290    pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
291        let config = self.config.unwrap_or_default();
292
293        let storage = {
294            let mut builder = Storage::builder()
295                .with_max_capacity(config.max_storage_capacity)
296                .with_max_ttl(config.max_stored_value_ttl);
297
298            if let Some(time_to_idle) = config.storage_item_time_to_idle {
299                builder = builder.with_max_idle(time_to_idle);
300            }
301
302            builder.build()
303        };
304
305        let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
306
307        let inner = Arc::new(DhtInner {
308            local_id: self.local_id,
309            routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
310            storage,
311            local_peer_info: Mutex::new(None),
312            config,
313            announced_peers,
314            find_value_queries: Default::default(),
315            peer_added: Arc::new(Default::default()),
316        });
317
318        let background_tasks = DhtServiceBackgroundTasks {
319            inner: inner.clone(),
320        };
321
322        (background_tasks, DhtService(inner))
323    }
324}
325
326#[derive(Clone)]
327#[repr(transparent)]
328pub struct DhtService(Arc<DhtInner>);
329
330impl DhtService {
331    #[inline]
332    fn wrap(inner: &Arc<DhtInner>) -> &Self {
333        // SAFETY: `DhtService` has the same memory layout as `Arc<DhtInner>`.
334        unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
335    }
336
337    pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
338        DhtServiceBuilder {
339            local_id,
340            config: None,
341        }
342    }
343
344    pub fn local_id(&self) -> &PeerId {
345        &self.0.local_id
346    }
347
348    pub fn make_client(&self, network: &Network) -> DhtClient {
349        DhtClient {
350            inner: self.0.clone(),
351            network: network.clone(),
352        }
353    }
354
355    pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
356        PeerResolver::builder(self.clone())
357    }
358
359    pub fn has_peer(&self, peer_id: &PeerId) -> bool {
360        self.0.routing_table.lock().unwrap().contains(peer_id)
361    }
362
363    pub fn find_local_closest(&self, key: &[u8; 32], count: usize) -> Vec<Arc<PeerInfo>> {
364        self.0.routing_table.lock().unwrap().closest(key, count)
365    }
366
367    pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
368        self.0.store_value_locally(value)
369    }
370
371    pub fn insert_merger(
372        &self,
373        group_id: &[u8; 32],
374        merger: Arc<dyn DhtValueMerger>,
375    ) -> Option<Arc<dyn DhtValueMerger>> {
376        self.0.storage.insert_merger(group_id, merger)
377    }
378
379    pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
380        self.0.storage.remove_merger(group_id)
381    }
382
383    pub fn peer_added(&self) -> &Arc<Notify> {
384        &self.0.peer_added
385    }
386}
387
388impl Service<ServiceRequest> for DhtService {
389    type QueryResponse = Response;
390    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
391    type OnMessageFuture = futures_util::future::Ready<()>;
392
393    #[tracing::instrument(
394        level = "debug",
395        name = "on_dht_query",
396        skip_all,
397        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
398    )]
399    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
400        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);
401
402        let (constructor, body) = match self.0.try_handle_prefix(&req) {
403            Ok(rest) => rest,
404            Err(e) => {
405                tracing::debug!("failed to deserialize query: {e}");
406                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
407                return futures_util::future::ready(None);
408            }
409        };
410
411        let response = crate::match_tl_request!(body, tag = constructor, {
412            rpc::FindNode as ref r => {
413                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
414                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);
415
416                let res = self.0.handle_find_node(r);
417                Some(tl_proto::serialize(res))
418            },
419            rpc::FindValue as ref r => {
420                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
421                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);
422
423                let res = self.0.handle_find_value(r);
424                Some(tl_proto::serialize(res))
425            },
426            rpc::GetNodeInfo as _ => {
427                tracing::debug!("get_node_info");
428                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);
429
430                self.0.handle_get_node_info().map(tl_proto::serialize)
431            },
432        }, e => {
433            tracing::debug!("failed to deserialize query: {e}");
434            None
435        });
436
437        if response.is_none() {
438            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
439        }
440
441        futures_util::future::ready(response.map(|body| Response {
442            version: Default::default(),
443            body: Bytes::from(body),
444        }))
445    }
446
447    #[tracing::instrument(
448        level = "debug",
449        name = "on_dht_message",
450        skip_all,
451        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
452    )]
453    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
454        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);
455
456        let (constructor, body) = match self.0.try_handle_prefix(&req) {
457            Ok(rest) => rest,
458            Err(e) => {
459                tracing::debug!("failed to deserialize message: {e}");
460                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
461                return futures_util::future::ready(());
462            }
463        };
464
465        let mut has_error = false;
466        crate::match_tl_request!(body, tag = constructor, {
467            rpc::StoreRef<'_> as ref r => {
468                tracing::debug!("store");
469                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);
470
471                if let Err(e) = self.0.handle_store(r) {
472                    tracing::debug!("failed to store value: {e}");
473                    has_error = true;
474                }
475            }
476        }, e => {
477            tracing::debug!("failed to deserialize message: {e}");
478            has_error = true;
479        });
480
481        if has_error {
482            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
483        }
484
485        futures_util::future::ready(())
486    }
487}
488
489impl Routable for DhtService {
490    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
491        [
492            rpc::WithPeerInfo::TL_ID,
493            rpc::FindNode::TL_ID,
494            rpc::FindValue::TL_ID,
495            rpc::GetNodeInfo::TL_ID,
496        ]
497    }
498
499    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
500        [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
501    }
502}
503
504struct DhtInner {
505    local_id: PeerId,
506    routing_table: Mutex<HandlesRoutingTable>,
507    storage: Storage,
508    local_peer_info: Mutex<Option<PeerInfo>>,
509    config: DhtConfig,
510    announced_peers: broadcast::Sender<Arc<PeerInfo>>,
511    find_value_queries: QueryCache<Option<Box<Value>>>,
512    peer_added: Arc<Notify>,
513}
514
515impl DhtInner {
516    async fn find_value(
517        &self,
518        network: &Network,
519        key_hash: &[u8; 32],
520        mode: DhtQueryMode,
521    ) -> Option<Box<Value>> {
522        self.find_value_queries
523            .run(key_hash, || {
524                let query = Query::new(
525                    network.clone(),
526                    &self.routing_table.lock().unwrap(),
527                    key_hash,
528                    self.config.max_k,
529                    mode,
530                );
531
532                // NOTE: expression is intentionally split to drop the routing table guard
533                Box::pin(query.find_value())
534            })
535            .await
536    }
537
538    async fn store_value(
539        &self,
540        network: &Network,
541        value: &ValueRef<'_>,
542        with_peer_info: bool,
543    ) -> Result<()> {
544        self.storage.insert(DhtValueSource::Local, value)?;
545
546        let local_peer_info = if with_peer_info {
547            let mut node_info = self.local_peer_info.lock().unwrap();
548            Some(
549                node_info
550                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
551                    .clone(),
552            )
553        } else {
554            None
555        };
556
557        let query = StoreValue::new(
558            network.clone(),
559            &self.routing_table.lock().unwrap(),
560            value,
561            self.config.max_k,
562            local_peer_info.as_ref(),
563        );
564
565        // NOTE: expression is intentionally split to drop the routing table guard
566        query.run().await;
567        Ok(())
568    }
569
570    fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
571        self.storage.insert(DhtValueSource::Local, value)
572    }
573
574    // NOTE: Requires the incoming peer info to be valid.
575    fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
576        if peer_info.id == self.local_id {
577            return false;
578        }
579
580        let mut routing_table = self.routing_table.lock().unwrap();
581        let added = routing_table.add(
582            peer_info.clone(),
583            self.config.max_k,
584            &self.config.max_peer_info_ttl,
585            |peer_info| network.known_peers().insert(peer_info, false).ok(),
586        );
587
588        if added {
589            self.peer_added.notify_waiters();
590        }
591
592        added
593    }
594
595    fn make_unsigned_peer_value<'a>(
596        &'a self,
597        name: PeerValueKeyName,
598        data: &'a [u8],
599        expires_at: u32,
600    ) -> PeerValueRef<'a> {
601        PeerValueRef {
602            key: PeerValueKeyRef {
603                name,
604                peer_id: &self.local_id,
605            },
606            data,
607            expires_at,
608            signature: &[0; 64],
609        }
610    }
611
612    fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
613        network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
614    }
615
616    fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
617        let mut body = req.as_ref();
618        anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
619
620        // NOTE: read constructor without advancing the body
621        let mut constructor = std::convert::identity(body).get_u32_le();
622
623        if constructor == rpc::WithPeerInfo::TL_ID {
624            metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);
625
626            let peer_info = rpc::WithPeerInfo::read_from(&mut body)?.peer_info;
627            anyhow::ensure!(
628                peer_info.id == req.metadata.peer_id,
629                "suggested peer ID does not belong to the sender"
630            );
631
632            anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
633            self.announced_peers.send(Arc::new(peer_info)).ok();
634
635            // NOTE: read constructor without advancing the body
636            constructor = std::convert::identity(body).get_u32_le();
637        }
638
639        Ok((constructor, body))
640    }
641
642    fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
643        self.storage.insert(DhtValueSource::Remote, &req.value)
644    }
645
646    fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
647        let nodes = self
648            .routing_table
649            .lock()
650            .unwrap()
651            .closest(&req.key, (req.k as usize).min(self.config.max_k));
652
653        NodeResponse { nodes }
654    }
655
656    fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
657        if let Some(value) = self.storage.get(&req.key) {
658            ValueResponseRaw::Found(value)
659        } else {
660            let nodes = self
661                .routing_table
662                .lock()
663                .unwrap()
664                .closest(&req.key, (req.k as usize).min(self.config.max_k));
665
666            ValueResponseRaw::NotFound(nodes)
667        }
668    }
669
670    fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
671        self.local_peer_info
672            .lock()
673            .unwrap()
674            .clone()
675            .map(|info| NodeInfoResponse { info })
676    }
677}
678
679fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
680    let distance = MAX_XOR_DISTANCE - distance;
681
682    let mut result = *from;
683
684    let byte_offset = distance / 8;
685    rng.fill_bytes(&mut result.0[byte_offset..]);
686
687    let bit_offset = distance % 8;
688    if bit_offset != 0 {
689        let mask = 0xff >> bit_offset;
690        result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
691    }
692
693    result
694}
695
696pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
697    for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
698        let left = u64::from_be_bytes(left.try_into().unwrap());
699        let right = u64::from_be_bytes(right.try_into().unwrap());
700        let diff = left ^ right;
701        if diff != 0 {
702            return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
703        }
704    }
705
706    0
707}
708
709const MAX_XOR_DISTANCE: usize = 256;
710
711#[derive(Debug, thiserror::Error)]
712pub enum FindValueError {
713    #[error("failed to deserialize value: {0}")]
714    InvalidData(#[from] tl_proto::TlError),
715    #[error("value not found")]
716    NotFound,
717}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// A key generated for a requested distance must never be further away
    /// from the origin than that distance, for every possible distance.
    #[test]
    fn proper_random_keys() {
        let mut rng = rand::rng();
        let peer_id: PeerId = rand::random();

        for requested in 0..=MAX_XOR_DISTANCE {
            let random_id = random_key_at_distance(&peer_id, requested, &mut rng);
            let distance = xor_distance(&peer_id, &random_id);

            assert!(
                distance <= requested,
                "{peer_id} -> {random_id}: distance {distance} exceeds requested {requested}"
            );
        }

        // A zero distance leaves no free bits, so the key must be the origin itself.
        let same_id = random_key_at_distance(&peer_id, 0, &mut rng);
        assert_eq!(xor_distance(&peer_id, &same_id), 0);
    }
}