// tycho_network/dht/mod.rs

1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{Notify, broadcast};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13    PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22    NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef,
23    PeerValueRef, Value, ValueRef, ValueResponseRaw, rpc,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
// Counters

/// Every incoming DHT request (query or message) handled by `DhtService`.
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Incoming requests that failed: a malformed prefix/body, a rejected store,
/// or a query that produced no response.
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Incoming requests that start with an `rpc::WithPeerInfo` prefix.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Incoming `rpc::FindNode` queries.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Incoming `rpc::FindValue` queries.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Incoming `rpc::GetNodeInfo` queries.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Incoming `rpc::Store` messages.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47    inner: Arc<DhtInner>,
48    network: Network,
49}
50
51impl DhtClient {
52    #[inline]
53    pub fn network(&self) -> &Network {
54        &self.network
55    }
56
57    #[inline]
58    pub fn service(&self) -> &DhtService {
59        DhtService::wrap(&self.inner)
60    }
61
62    pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63        anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64        let added = self.inner.add_peer_info(&self.network, peer);
65        Ok(added)
66    }
67
68    pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
69        let res = self
70            .network
71            .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
72            .await?;
73        let NodeInfoResponse { info } = res.parse_tl()?;
74        Ok(info)
75    }
76
77    pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
78        DhtQueryBuilder {
79            inner: &self.inner,
80            network: &self.network,
81            name,
82            idx: 0,
83        }
84    }
85
86    /// Find a value by its key hash.
87    ///
88    /// This is quite a low-level method, so it is recommended to use [`DhtClient::entry`].
89    pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
90        self.inner.find_value(&self.network, key_hash, mode).await
91    }
92}
93
94#[derive(Clone, Copy)]
95pub struct DhtQueryBuilder<'a> {
96    inner: &'a DhtInner,
97    network: &'a Network,
98    name: PeerValueKeyName,
99    idx: u32,
100}
101
102impl<'a> DhtQueryBuilder<'a> {
103    #[inline]
104    pub fn with_idx(&mut self, idx: u32) -> &mut Self {
105        self.idx = idx;
106        self
107    }
108
109    pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
110    where
111        for<'tl> T: tl_proto::TlRead<'tl>,
112    {
113        let key_hash = tl_proto::hash(PeerValueKeyRef {
114            name: self.name,
115            peer_id,
116        });
117
118        match self
119            .inner
120            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
121            .await
122        {
123            Some(value) => match value.as_ref() {
124                Value::Peer(value) => {
125                    tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
126                }
127                Value::Merged(_) => Err(FindValueError::InvalidData(
128                    tl_proto::TlError::UnknownConstructor,
129                )),
130            },
131            None => Err(FindValueError::NotFound),
132        }
133    }
134
135    pub async fn find_peer_value_raw(
136        &self,
137        peer_id: &PeerId,
138    ) -> Result<Box<PeerValue>, FindValueError> {
139        let key_hash = tl_proto::hash(PeerValueKeyRef {
140            name: self.name,
141            peer_id,
142        });
143
144        match self
145            .inner
146            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
147            .await
148        {
149            Some(value) => {
150                realloc_box_enum!(value, {
151                    Value::Peer(value) => Box::new(value) => Ok(value),
152                    Value::Merged(_) => Err(FindValueError::InvalidData(
153                        tl_proto::TlError::UnknownConstructor,
154                    )),
155                })
156            }
157            None => Err(FindValueError::NotFound),
158        }
159    }
160
161    pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
162    where
163        T: tl_proto::TlWrite,
164    {
165        DhtQueryWithDataBuilder {
166            inner: *self,
167            data: tl_proto::serialize(&data),
168            at: None,
169            ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
170            with_peer_info: false,
171        }
172    }
173}
174
175pub struct DhtQueryWithDataBuilder<'a> {
176    inner: DhtQueryBuilder<'a>,
177    data: Vec<u8>,
178    at: Option<u32>,
179    ttl: u32,
180    with_peer_info: bool,
181}
182
183impl DhtQueryWithDataBuilder<'_> {
184    pub fn with_time(&mut self, at: u32) -> &mut Self {
185        self.at = Some(at);
186        self
187    }
188
189    pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
190        self.ttl = ttl;
191        self
192    }
193
194    pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
195        self.with_peer_info = with_peer_info;
196        self
197    }
198
199    pub async fn store(&self) -> Result<()> {
200        let dht = self.inner.inner;
201        let network = self.inner.network;
202
203        let mut value = self.make_unsigned_value_ref();
204        let signature = network.sign_tl(&value);
205        value.signature = &signature;
206
207        dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
208            .await
209    }
210
211    pub fn store_locally(&self) -> Result<bool, StorageError> {
212        let dht = self.inner.inner;
213        let network = self.inner.network;
214
215        let mut value = self.make_unsigned_value_ref();
216        let signature = network.sign_tl(&value);
217        value.signature = &signature;
218
219        dht.store_value_locally(&ValueRef::Peer(value))
220    }
221
222    pub fn into_signed_value(self) -> PeerValue {
223        let dht = self.inner.inner;
224        let network = self.inner.network;
225
226        let mut value = PeerValue {
227            key: PeerValueKey {
228                name: self.name,
229                peer_id: dht.local_id,
230            },
231            data: self.data.into_boxed_slice(),
232            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
233            signature: Box::new([0; 64]),
234        };
235        *value.signature = network.sign_tl(&value);
236        value
237    }
238
239    fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
240        PeerValueRef {
241            key: PeerValueKeyRef {
242                name: self.inner.name,
243                peer_id: &self.inner.inner.local_id,
244            },
245            data: &self.data,
246            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
247            signature: &[0; 64],
248        }
249    }
250}
251
252impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
253    type Target = DhtQueryBuilder<'a>;
254
255    #[inline]
256    fn deref(&self) -> &Self::Target {
257        &self.inner
258    }
259}
260
261impl std::ops::DerefMut for DhtQueryWithDataBuilder<'_> {
262    #[inline]
263    fn deref_mut(&mut self) -> &mut Self::Target {
264        &mut self.inner
265    }
266}
267
268pub struct DhtServiceBackgroundTasks {
269    inner: Arc<DhtInner>,
270}
271
272impl DhtServiceBackgroundTasks {
273    pub fn spawn(self, network: &Network) {
274        self.inner
275            .start_background_tasks(Network::downgrade(network));
276    }
277}
278
279pub struct DhtServiceBuilder {
280    local_id: PeerId,
281    config: Option<DhtConfig>,
282}
283
284impl DhtServiceBuilder {
285    pub fn with_config(mut self, config: DhtConfig) -> Self {
286        self.config = Some(config);
287        self
288    }
289
290    pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
291        let config = self.config.unwrap_or_default();
292
293        let storage = {
294            let mut builder = Storage::builder()
295                .with_max_capacity(config.max_storage_capacity)
296                .with_max_ttl(config.max_stored_value_ttl);
297
298            if let Some(time_to_idle) = config.storage_item_time_to_idle {
299                builder = builder.with_max_idle(time_to_idle);
300            }
301
302            builder.build()
303        };
304
305        let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
306
307        let inner = Arc::new(DhtInner {
308            local_id: self.local_id,
309            routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
310            storage,
311            local_peer_info: Mutex::new(None),
312            config,
313            announced_peers,
314            find_value_queries: Default::default(),
315            peer_added: Arc::new(Default::default()),
316        });
317
318        let background_tasks = DhtServiceBackgroundTasks {
319            inner: inner.clone(),
320        };
321
322        (background_tasks, DhtService(inner))
323    }
324}
325
326#[derive(Clone)]
327#[repr(transparent)]
328pub struct DhtService(Arc<DhtInner>);
329
330impl DhtService {
331    #[inline]
332    fn wrap(inner: &Arc<DhtInner>) -> &Self {
333        // SAFETY: `DhtService` has the same memory layout as `Arc<DhtInner>`.
334        unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
335    }
336
337    pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
338        DhtServiceBuilder {
339            local_id,
340            config: None,
341        }
342    }
343
344    pub fn local_id(&self) -> &PeerId {
345        &self.0.local_id
346    }
347
348    pub fn make_client(&self, network: &Network) -> DhtClient {
349        DhtClient {
350            inner: self.0.clone(),
351            network: network.clone(),
352        }
353    }
354
355    pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
356        PeerResolver::builder(self.clone())
357    }
358
359    pub fn has_peer(&self, peer_id: &PeerId) -> bool {
360        self.0.routing_table.lock().unwrap().contains(peer_id)
361    }
362
363    pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
364        self.0.store_value_locally(value)
365    }
366
367    pub fn insert_merger(
368        &self,
369        group_id: &[u8; 32],
370        merger: Arc<dyn DhtValueMerger>,
371    ) -> Option<Arc<dyn DhtValueMerger>> {
372        self.0.storage.insert_merger(group_id, merger)
373    }
374
375    pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
376        self.0.storage.remove_merger(group_id)
377    }
378
379    pub fn peer_added(&self) -> &Arc<Notify> {
380        &self.0.peer_added
381    }
382}
383
/// Inbound request handling for the DHT protocol.
///
/// Queries and messages may both be wrapped in an `rpc::WithPeerInfo` prefix. The prefix is
/// stripped, and its peer info announced, by `DhtInner::try_handle_prefix` before dispatch.
impl Service<ServiceRequest> for DhtService {
    type QueryResponse = Response;
    // All handlers below are synchronous, so both futures are always already resolved.
    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
    type OnMessageFuture = futures_util::future::Ready<()>;

    /// Handles an incoming DHT query: `FindNode`, `FindValue` or `GetNodeInfo`.
    ///
    /// Returns the TL-serialized response. Returns `None` when the request cannot be parsed or
    /// no response is produced (`GetNodeInfo` before the local peer info exists). Every call
    /// increments `METRIC_IN_REQ_TOTAL`; every `None` also increments `METRIC_IN_REQ_FAIL_TOTAL`.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_query",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        // `constructor` is the tag of the actual request (after any `WithPeerInfo` prefix);
        // `body` still starts with that constructor.
        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize query: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(None);
            }
        };

        let response = crate::match_tl_request!(body, tag = constructor, {
            rpc::FindNode as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);

                let res = self.0.handle_find_node(r);
                Some(tl_proto::serialize(res))
            },
            rpc::FindValue as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);

                let res = self.0.handle_find_value(r);
                Some(tl_proto::serialize(res))
            },
            rpc::GetNodeInfo as _ => {
                tracing::debug!("get_node_info");
                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);

                self.0.handle_get_node_info().map(tl_proto::serialize)
            },
        }, e => {
            tracing::debug!("failed to deserialize query: {e}");
            None
        });

        if response.is_none() {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(response.map(|body| Response {
            version: Default::default(),
            body: Bytes::from(body),
        }))
    }

    /// Handles an incoming one-way DHT message; only `Store` is supported.
    ///
    /// Nothing is returned to the sender. Parse or storage failures are logged at debug level
    /// and counted in `METRIC_IN_REQ_FAIL_TOTAL`.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_message",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize message: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(());
            }
        };

        // Set by either a rejected store or a body that fails to deserialize.
        let mut has_error = false;
        crate::match_tl_request!(body, tag = constructor, {
            rpc::StoreRef<'_> as ref r => {
                tracing::debug!("store");
                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);

                if let Err(e) = self.0.handle_store(r) {
                    tracing::debug!("failed to store value: {e}");
                    has_error = true;
                }
            }
        }, e => {
            tracing::debug!("failed to deserialize message: {e}");
            has_error = true;
        });

        if has_error {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(())
    }
}
484
485impl Routable for DhtService {
486    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
487        [
488            rpc::WithPeerInfo::TL_ID,
489            rpc::FindNode::TL_ID,
490            rpc::FindValue::TL_ID,
491            rpc::GetNodeInfo::TL_ID,
492        ]
493    }
494
495    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
496        [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
497    }
498}
499
/// Shared state behind [`DhtService`], [`DhtClient`] and the DHT background tasks.
struct DhtInner {
    /// ID of the local node.
    local_id: PeerId,
    /// Known remote peers. This is a `std::sync::Mutex`, so its guard must be dropped before
    /// any `.await` (see the NOTEs in `find_value` / `store_value`).
    routing_table: Mutex<HandlesRoutingTable>,
    /// Value storage holding both locally published (`Local`) and remotely received
    /// (`Remote`) values.
    storage: Storage,
    /// Signed info about the local node, served by `handle_get_node_info`.
    /// Created lazily by `store_value` when peer info is requested.
    /// NOTE(review): the background tasks may also refresh it — not visible in this file.
    local_peer_info: Mutex<Option<PeerInfo>>,
    config: DhtConfig,
    /// Peer infos announced by remote senders through the `rpc::WithPeerInfo` prefix.
    announced_peers: broadcast::Sender<Arc<PeerInfo>>,
    /// `find_value` queries keyed by key hash (presumably shares one network lookup between
    /// concurrent callers — confirm in `query.rs`).
    find_value_queries: QueryCache<Option<Box<Value>>>,
    /// Notified after `add_peer_info` successfully inserts a peer.
    peer_added: Arc<Notify>,
}

impl DhtInner {
    /// Looks up a value by key hash through `find_value_queries`.
    ///
    /// When needed, a new `Query` is built from a snapshot of the current routing table,
    /// using at most `max_k` peers per step and the given `mode`.
    async fn find_value(
        &self,
        network: &Network,
        key_hash: &[u8; 32],
        mode: DhtQueryMode,
    ) -> Option<Box<Value>> {
        self.find_value_queries
            .run(key_hash, || {
                let query = Query::new(
                    network.clone(),
                    &self.routing_table.lock().unwrap(),
                    key_hash,
                    self.config.max_k,
                    mode,
                );

                // NOTE: expression is intentionally split to drop the routing table guard
                Box::pin(query.find_value())
            })
            .await
    }

    /// Stores an already signed `value` locally, then replicates it with a `StoreValue` query.
    ///
    /// When `with_peer_info` is set, the local peer info is attached to the query. It is created
    /// and cached on first use.
    ///
    /// # Errors
    /// Returns the storage error if the local insert is rejected. The outcome of the network
    /// replication is not propagated.
    async fn store_value(
        &self,
        network: &Network,
        value: &ValueRef<'_>,
        with_peer_info: bool,
    ) -> Result<()> {
        self.storage.insert(DhtValueSource::Local, value)?;

        let local_peer_info = if with_peer_info {
            let mut node_info = self.local_peer_info.lock().unwrap();
            Some(
                node_info
                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
                    .clone(),
            )
        } else {
            None
        };

        let query = StoreValue::new(
            network.clone(),
            &self.routing_table.lock().unwrap(),
            value,
            self.config.max_k,
            local_peer_info.as_ref(),
        );

        // NOTE: expression is intentionally split to drop the routing table guard
        query.run().await;
        Ok(())
    }

    /// Inserts an already signed `value` into the local storage only.
    fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
        self.storage.insert(DhtValueSource::Local, value)
    }

    /// Offers a peer to the routing table; the local node itself is always rejected.
    ///
    /// On insertion the routing table's callback also registers the peer in the network's
    /// known peers. Returns whether the peer was added, and wakes `peer_added` waiters if so.
    // NOTE: Requires the incoming peer info to be valid.
    fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
        if peer_info.id == self.local_id {
            return false;
        }

        let mut routing_table = self.routing_table.lock().unwrap();
        let added = routing_table.add(
            peer_info.clone(),
            self.config.max_k,
            &self.config.max_peer_info_ttl,
            |peer_info| network.known_peers().insert(peer_info, false).ok(),
        );

        if added {
            self.peer_added.notify_waiters();
        }

        added
    }

    /// Builds a `PeerValueRef` for a key owned by the local node, with a zeroed signature
    /// placeholder (presumably filled in by the caller after signing).
    ///
    /// NOTE(review): not referenced in this file; presumably used by a submodule.
    fn make_unsigned_peer_value<'a>(
        &'a self,
        name: PeerValueKeyName,
        data: &'a [u8],
        expires_at: u32,
    ) -> PeerValueRef<'a> {
        PeerValueRef {
            key: PeerValueKeyRef {
                name,
                peer_id: &self.local_id,
            },
            data,
            expires_at,
            signature: &[0; 64],
        }
    }

    /// Signs fresh local peer info starting at `now`, with a TTL of `max_peer_info_ttl`
    /// (whole seconds).
    fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
        network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
    }

    /// Reads the leading TL constructor of `req`, handling an optional `WithPeerInfo` prefix.
    ///
    /// If the prefix is present, it is consumed, and the embedded peer ID must match the sender.
    /// The peer info is then broadcast on `announced_peers`; send errors (e.g. no subscribers)
    /// are ignored.
    /// NOTE(review): the announced info's signature is not verified here; presumably consumers
    /// verify it.
    ///
    /// Returns the request's constructor together with a body that still starts with it.
    fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
        let mut body = req.as_ref();
        anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);

        // NOTE: read constructor without advancing the body
        let mut constructor = std::convert::identity(body).get_u32_le();

        if constructor == rpc::WithPeerInfo::TL_ID {
            metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);

            let peer_info = rpc::WithPeerInfo::read_from(&mut body)?.peer_info;
            anyhow::ensure!(
                peer_info.id == req.metadata.peer_id,
                "suggested peer ID does not belong to the sender"
            );

            anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
            self.announced_peers.send(Arc::new(peer_info)).ok();

            // NOTE: read constructor without advancing the body
            constructor = std::convert::identity(body).get_u32_le();
        }

        Ok((constructor, body))
    }

    /// Inserts a value received from a remote peer into the storage.
    fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
        self.storage.insert(DhtValueSource::Remote, &req.value)
    }

    /// Returns up to `min(req.k, max_k)` known nodes closest to `req.key`.
    fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
        let nodes = self
            .routing_table
            .lock()
            .unwrap()
            .closest(&req.key, (req.k as usize).min(self.config.max_k));

        NodeResponse { nodes }
    }

    /// Returns the stored value for `req.key` if present.
    ///
    /// Otherwise returns up to `min(req.k, max_k)` closest known nodes.
    fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
        if let Some(value) = self.storage.get(&req.key) {
            ValueResponseRaw::Found(value)
        } else {
            let nodes = self
                .routing_table
                .lock()
                .unwrap()
                .closest(&req.key, (req.k as usize).min(self.config.max_k));

            ValueResponseRaw::NotFound(nodes)
        }
    }

    /// Returns the local peer info, or `None` if it has not been created yet.
    fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
        self.local_peer_info
            .lock()
            .unwrap()
            .clone()
            .map(|info| NodeInfoResponse { info })
    }
}
674
675fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
676    let distance = MAX_XOR_DISTANCE - distance;
677
678    let mut result = *from;
679
680    let byte_offset = distance / 8;
681    rng.fill_bytes(&mut result.0[byte_offset..]);
682
683    let bit_offset = distance % 8;
684    if bit_offset != 0 {
685        let mask = 0xff >> bit_offset;
686        result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
687    }
688
689    result
690}
691
692pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
693    for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
694        let left = u64::from_be_bytes(left.try_into().unwrap());
695        let right = u64::from_be_bytes(right.try_into().unwrap());
696        let diff = left ^ right;
697        if diff != 0 {
698            return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
699        }
700    }
701
702    0
703}
704
705const MAX_XOR_DISTANCE: usize = 256;
706
/// Error returned by the typed lookups of [`DhtQueryBuilder`].
#[derive(Debug, thiserror::Error)]
pub enum FindValueError {
    /// The value was found but could not be decoded. This is also used when the found value
    /// is not a peer value.
    #[error("failed to deserialize value: {0}")]
    InvalidData(#[from] tl_proto::TlError),
    /// The DHT query finished without finding a value.
    #[error("value not found")]
    NotFound,
}
714
#[cfg(test)]
mod tests {
    use super::*;

    /// `random_key_at_distance` copies the first `MAX_XOR_DISTANCE - distance` bits of the
    /// source, so the resulting XOR distance can never exceed `distance`.
    #[test]
    fn proper_random_keys() {
        let mut rng = rand::rng();
        for distance in [0, 1, 7, 8, 9, 20, 63, 64, 65, 255, MAX_XOR_DISTANCE] {
            let peer_id: PeerId = rand::random();
            let random_id = random_key_at_distance(&peer_id, distance, &mut rng);

            let actual = xor_distance(&peer_id, &random_id);
            assert!(
                actual <= distance,
                "xor distance {actual} exceeds {distance} for {peer_id} -> {random_id}"
            );
        }
    }

    /// Pins the exact bit-count semantics of `xor_distance`.
    #[test]
    fn xor_distance_counts_differing_prefix() {
        let peer_id: PeerId = rand::random();
        assert_eq!(xor_distance(&peer_id, &peer_id), 0);

        let mut last_bit = peer_id;
        last_bit.0[31] ^= 0x01;
        assert_eq!(xor_distance(&peer_id, &last_bit), 1);

        let mut first_bit = peer_id;
        first_bit.0[0] ^= 0x80;
        assert_eq!(xor_distance(&peer_id, &first_bit), MAX_XOR_DISTANCE);
    }
}