tycho_network/dht/
mod.rs

1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{Notify, broadcast};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13    PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22    NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef,
23    PeerValueRef, Value, ValueRef, ValueResponseRaw, rpc,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
// Counters

/// Every incoming DHT request, incremented at the start of both `on_query` and `on_message`.
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Incoming DHT requests that failed: bad prefix, parse error, handler error or no response.
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Incoming requests that carried an `rpc::WithPeerInfo` prefix.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Incoming `rpc::FindNode` queries.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Incoming `rpc::FindValue` queries.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Incoming `rpc::GetNodeInfo` queries.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Incoming `rpc::Store` messages.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47    inner: Arc<DhtInner>,
48    network: Network,
49}
50
51impl DhtClient {
52    #[inline]
53    pub fn network(&self) -> &Network {
54        &self.network
55    }
56
57    #[inline]
58    pub fn service(&self) -> &DhtService {
59        DhtService::wrap(&self.inner)
60    }
61
62    pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63        anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64        let added = self.inner.add_peer_info(&self.network, peer);
65        Ok(added)
66    }
67
68    pub fn add_allow_outdated_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
69        anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
70        let added = self.inner.add_allow_outdated_peer_info(&self.network, peer);
71        Ok(added)
72    }
73
74    pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
75        let res = self
76            .network
77            .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
78            .await?;
79        let NodeInfoResponse { info } = res.parse_tl()?;
80        Ok(info)
81    }
82
83    pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
84        DhtQueryBuilder {
85            inner: &self.inner,
86            network: &self.network,
87            name,
88            idx: 0,
89        }
90    }
91
92    /// Find a value by its key hash.
93    ///
94    /// This is quite a low-level method, so it is recommended to use [`DhtClient::entry`].
95    pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
96        self.inner.find_value(&self.network, key_hash, mode).await
97    }
98}
99
100#[derive(Clone, Copy)]
101pub struct DhtQueryBuilder<'a> {
102    inner: &'a DhtInner,
103    network: &'a Network,
104    name: PeerValueKeyName,
105    idx: u32,
106}
107
108impl<'a> DhtQueryBuilder<'a> {
109    #[inline]
110    pub fn with_idx(&mut self, idx: u32) -> &mut Self {
111        self.idx = idx;
112        self
113    }
114
115    pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
116    where
117        for<'tl> T: tl_proto::TlRead<'tl>,
118    {
119        let key_hash = tl_proto::hash(PeerValueKeyRef {
120            name: self.name,
121            peer_id,
122        });
123
124        match self
125            .inner
126            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
127            .await
128        {
129            Some(value) => match value.as_ref() {
130                Value::Peer(value) => {
131                    tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
132                }
133                Value::Merged(_) => Err(FindValueError::InvalidData(
134                    tl_proto::TlError::UnknownConstructor,
135                )),
136            },
137            None => Err(FindValueError::NotFound),
138        }
139    }
140
141    pub async fn find_peer_value_raw(
142        &self,
143        peer_id: &PeerId,
144    ) -> Result<Box<PeerValue>, FindValueError> {
145        let key_hash = tl_proto::hash(PeerValueKeyRef {
146            name: self.name,
147            peer_id,
148        });
149
150        match self
151            .inner
152            .find_value(self.network, &key_hash, DhtQueryMode::Closest)
153            .await
154        {
155            Some(value) => {
156                realloc_box_enum!(value, {
157                    Value::Peer(value) => Box::new(value) => Ok(value),
158                    Value::Merged(_) => Err(FindValueError::InvalidData(
159                        tl_proto::TlError::UnknownConstructor,
160                    )),
161                })
162            }
163            None => Err(FindValueError::NotFound),
164        }
165    }
166
167    pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
168    where
169        T: tl_proto::TlWrite,
170    {
171        DhtQueryWithDataBuilder {
172            inner: *self,
173            data: tl_proto::serialize(&data),
174            at: None,
175            ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
176            with_peer_info: false,
177        }
178    }
179}
180
181pub struct DhtQueryWithDataBuilder<'a> {
182    inner: DhtQueryBuilder<'a>,
183    data: Vec<u8>,
184    at: Option<u32>,
185    ttl: u32,
186    with_peer_info: bool,
187}
188
189impl DhtQueryWithDataBuilder<'_> {
190    pub fn with_time(&mut self, at: u32) -> &mut Self {
191        self.at = Some(at);
192        self
193    }
194
195    pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
196        self.ttl = ttl;
197        self
198    }
199
200    pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
201        self.with_peer_info = with_peer_info;
202        self
203    }
204
205    pub async fn store(&self) -> Result<()> {
206        let dht = self.inner.inner;
207        let network = self.inner.network;
208
209        let mut value = self.make_unsigned_value_ref();
210        let signature = network.sign_tl(&value);
211        value.signature = &signature;
212
213        dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
214            .await
215    }
216
217    pub fn store_locally(&self) -> Result<bool, StorageError> {
218        let dht = self.inner.inner;
219        let network = self.inner.network;
220
221        let mut value = self.make_unsigned_value_ref();
222        let signature = network.sign_tl(&value);
223        value.signature = &signature;
224
225        dht.store_value_locally(&ValueRef::Peer(value))
226    }
227
228    pub fn into_signed_value(self) -> PeerValue {
229        let dht = self.inner.inner;
230        let network = self.inner.network;
231
232        let mut value = PeerValue {
233            key: PeerValueKey {
234                name: self.name,
235                peer_id: dht.local_id,
236            },
237            data: self.data.into_boxed_slice(),
238            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
239            signature: Box::new([0; 64]),
240        };
241        *value.signature = network.sign_tl(&value);
242        value
243    }
244
245    fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
246        PeerValueRef {
247            key: PeerValueKeyRef {
248                name: self.inner.name,
249                peer_id: &self.inner.inner.local_id,
250            },
251            data: &self.data,
252            expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
253            signature: &[0; 64],
254        }
255    }
256}
257
258impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
259    type Target = DhtQueryBuilder<'a>;
260
261    #[inline]
262    fn deref(&self) -> &Self::Target {
263        &self.inner
264    }
265}
266
267impl std::ops::DerefMut for DhtQueryWithDataBuilder<'_> {
268    #[inline]
269    fn deref_mut(&mut self) -> &mut Self::Target {
270        &mut self.inner
271    }
272}
273
274pub struct DhtServiceBackgroundTasks {
275    inner: Arc<DhtInner>,
276}
277
278impl DhtServiceBackgroundTasks {
279    pub fn spawn_without_bootstrap(self, network: &Network) {
280        self.inner
281            .start_background_tasks(Network::downgrade(network), Vec::new());
282    }
283
284    pub fn spawn<I>(self, network: &Network, bootstrap_peers: I) -> Result<usize>
285    where
286        I: IntoIterator<Item: std::ops::Deref<Target = PeerInfo>>,
287    {
288        let mut peers = Vec::new();
289        for peer in bootstrap_peers {
290            let peer = &*peer;
291
292            anyhow::ensure!(
293                peer.verify(now_sec()),
294                "invalid peer info for id {}",
295                peer.id
296            );
297            let peer = Arc::new(peer.clone());
298            if self.inner.add_peer_info(network, peer.clone()) {
299                peers.push(peer);
300            }
301        }
302
303        let count = peers.len();
304        self.inner
305            .start_background_tasks(Network::downgrade(network), peers);
306        Ok(count)
307    }
308}
309
310pub struct DhtServiceBuilder {
311    local_id: PeerId,
312    config: Option<DhtConfig>,
313}
314
315impl DhtServiceBuilder {
316    pub fn with_config(mut self, config: DhtConfig) -> Self {
317        self.config = Some(config);
318        self
319    }
320
321    pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
322        let config = self.config.unwrap_or_default();
323
324        let storage = {
325            let mut builder = Storage::builder()
326                .with_max_capacity(config.max_storage_capacity)
327                .with_max_ttl(config.max_stored_value_ttl);
328
329            if let Some(time_to_idle) = config.storage_item_time_to_idle {
330                builder = builder.with_max_idle(time_to_idle);
331            }
332
333            builder.build()
334        };
335
336        let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
337
338        let inner = Arc::new(DhtInner {
339            local_id: self.local_id,
340            routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
341            storage,
342            local_peer_info: Mutex::new(None),
343            config,
344            announced_peers,
345            find_value_queries: Default::default(),
346            peer_added: Arc::new(Default::default()),
347        });
348
349        let background_tasks = DhtServiceBackgroundTasks {
350            inner: inner.clone(),
351        };
352
353        (background_tasks, DhtService(inner))
354    }
355}
356
357#[derive(Clone)]
358#[repr(transparent)]
359pub struct DhtService(Arc<DhtInner>);
360
361impl DhtService {
362    #[inline]
363    fn wrap(inner: &Arc<DhtInner>) -> &Self {
364        // SAFETY: `DhtService` has the same memory layout as `Arc<DhtInner>`.
365        unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
366    }
367
368    pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
369        DhtServiceBuilder {
370            local_id,
371            config: None,
372        }
373    }
374
375    pub fn local_id(&self) -> &PeerId {
376        &self.0.local_id
377    }
378
379    pub fn make_client(&self, network: &Network) -> DhtClient {
380        DhtClient {
381            inner: self.0.clone(),
382            network: network.clone(),
383        }
384    }
385
386    pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
387        PeerResolver::builder(self.clone())
388    }
389
390    pub fn has_peer(&self, peer_id: &PeerId) -> bool {
391        self.0.routing_table.lock().unwrap().contains(peer_id)
392    }
393
394    pub fn find_local_closest(&self, key: &[u8; 32], count: usize) -> Vec<Arc<PeerInfo>> {
395        self.0.routing_table.lock().unwrap().closest(key, count)
396    }
397
398    pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
399        self.0.store_value_locally(value)
400    }
401
402    pub fn insert_merger(
403        &self,
404        group_id: &[u8; 32],
405        merger: Arc<dyn DhtValueMerger>,
406    ) -> Option<Arc<dyn DhtValueMerger>> {
407        self.0.storage.insert_merger(group_id, merger)
408    }
409
410    pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
411        self.0.storage.remove_merger(group_id)
412    }
413
414    pub fn peer_added(&self) -> &Arc<Notify> {
415        &self.0.peer_added
416    }
417}
418
/// Handles incoming DHT requests.
///
/// Every request first goes through `DhtInner::try_handle_prefix`, which strips
/// an optional `rpc::WithPeerInfo` prefix. The rest is dispatched on its TL
/// constructor id. Handling is synchronous: both futures are returned already
/// resolved via `futures_util::future::ready`.
impl Service<ServiceRequest> for DhtService {
    type QueryResponse = Response;
    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
    type OnMessageFuture = futures_util::future::Ready<()>;

    /// Answers `FindNode`, `FindValue` and `GetNodeInfo` queries.
    ///
    /// Yields `None` and increments the failure counter in these cases:
    /// - the prefix or body cannot be parsed;
    /// - the handler has nothing to answer (e.g. `GetNodeInfo` while no local
    ///   peer info is set).
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_query",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        // `constructor` is the TL id at the start of `body`, which is already past
        // any `WithPeerInfo` prefix.
        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize query: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(None);
            }
        };

        let response = crate::match_tl_request!(body, tag = constructor, {
            rpc::FindNode as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);

                let res = self.0.handle_find_node(r);
                Some(tl_proto::serialize(res))
            },
            rpc::FindValue as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);

                let res = self.0.handle_find_value(r);
                Some(tl_proto::serialize(res))
            },
            rpc::GetNodeInfo as _ => {
                tracing::debug!("get_node_info");
                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);

                self.0.handle_get_node_info().map(tl_proto::serialize)
            },
        }, e => {
            tracing::debug!("failed to deserialize query: {e}");
            None
        });

        // A missing response covers both the parse-error arm above and handlers
        // that returned `None`.
        if response.is_none() {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(response.map(|body| Response {
            version: Default::default(),
            body: Bytes::from(body),
        }))
    }

    /// Handles `Store` messages by inserting the value into the local storage as
    /// a remote value.
    ///
    /// Messages have no response. Failures are only logged at debug level and
    /// counted.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_message",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize message: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(());
            }
        };

        // Set by either a storage rejection or a parse error; counted once below.
        let mut has_error = false;
        crate::match_tl_request!(body, tag = constructor, {
            rpc::StoreRef<'_> as ref r => {
                tracing::debug!("store");
                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);

                if let Err(e) = self.0.handle_store(r) {
                    tracing::debug!("failed to store value: {e}");
                    has_error = true;
                }
            }
        }, e => {
            tracing::debug!("failed to deserialize message: {e}");
            has_error = true;
        });

        if has_error {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(())
    }
}
519
520impl Routable for DhtService {
521    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
522        [
523            rpc::WithPeerInfo::TL_ID,
524            rpc::FindNode::TL_ID,
525            rpc::FindValue::TL_ID,
526            rpc::GetNodeInfo::TL_ID,
527        ]
528    }
529
530    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
531        [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
532    }
533}
534
535struct DhtInner {
536    local_id: PeerId,
537    routing_table: Mutex<HandlesRoutingTable>,
538    storage: Storage,
539    local_peer_info: Mutex<Option<PeerInfo>>,
540    config: DhtConfig,
541    announced_peers: broadcast::Sender<Arc<PeerInfo>>,
542    find_value_queries: QueryCache<Option<Box<Value>>>,
543    peer_added: Arc<Notify>,
544}
545
546impl DhtInner {
547    async fn find_value(
548        &self,
549        network: &Network,
550        key_hash: &[u8; 32],
551        mode: DhtQueryMode,
552    ) -> Option<Box<Value>> {
553        self.find_value_queries
554            .run(key_hash, || {
555                let query = Query::new(
556                    network.clone(),
557                    &self.routing_table.lock().unwrap(),
558                    key_hash,
559                    self.config.max_k,
560                    mode,
561                );
562
563                // NOTE: expression is intentionally split to drop the routing table guard
564                Box::pin(query.find_value())
565            })
566            .await
567    }
568
569    async fn store_value(
570        &self,
571        network: &Network,
572        value: &ValueRef<'_>,
573        with_peer_info: bool,
574    ) -> Result<()> {
575        self.storage.insert(DhtValueSource::Local, value)?;
576
577        let local_peer_info = if with_peer_info {
578            let mut node_info = self.local_peer_info.lock().unwrap();
579            Some(
580                node_info
581                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
582                    .clone(),
583            )
584        } else {
585            None
586        };
587
588        let query = StoreValue::new(
589            network.clone(),
590            &self.routing_table.lock().unwrap(),
591            value,
592            self.config.max_k,
593            local_peer_info.as_ref(),
594        );
595
596        // NOTE: expression is intentionally split to drop the routing table guard
597        query.run().await;
598        Ok(())
599    }
600
601    fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
602        self.storage.insert(DhtValueSource::Local, value)
603    }
604
605    // NOTE: Requires the incoming peer info to be valid.
606    fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
607        if peer_info.id == self.local_id {
608            return false;
609        }
610
611        let mut routing_table = self.routing_table.lock().unwrap();
612        let added = routing_table.add(
613            peer_info.clone(),
614            self.config.max_k,
615            &self.config.max_peer_info_ttl,
616            |peer_info| network.known_peers().insert(peer_info, false).ok(),
617        );
618        drop(routing_table);
619
620        if added {
621            self.peer_added.notify_waiters();
622        }
623
624        added
625    }
626
627    fn add_allow_outdated_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
628        if peer_info.id == self.local_id {
629            return false;
630        }
631
632        let mut added = false;
633
634        let mut routing_table = self.routing_table.lock().unwrap();
635        if !routing_table.contains(&peer_info.id) {
636            added = routing_table.add(
637                peer_info.clone(),
638                self.config.max_k,
639                &self.config.max_peer_info_ttl,
640                |peer_info| {
641                    network
642                        .known_peers()
643                        .insert_allow_outdated(peer_info, false)
644                        .ok()
645                },
646            );
647        }
648        drop(routing_table);
649
650        if added {
651            self.peer_added.notify_waiters();
652        }
653
654        added
655    }
656
657    fn make_unsigned_peer_value<'a>(
658        &'a self,
659        name: PeerValueKeyName,
660        data: &'a [u8],
661        expires_at: u32,
662    ) -> PeerValueRef<'a> {
663        PeerValueRef {
664            key: PeerValueKeyRef {
665                name,
666                peer_id: &self.local_id,
667            },
668            data,
669            expires_at,
670            signature: &[0; 64],
671        }
672    }
673
674    fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
675        network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
676    }
677
678    fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
679        let mut body = req.as_ref();
680        anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
681
682        // NOTE: read constructor without advancing the body
683        let mut constructor = std::convert::identity(body).get_u32_le();
684
685        if constructor == rpc::WithPeerInfo::TL_ID {
686            metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);
687
688            let peer_info = rpc::WithPeerInfo::read_from(&mut body)?.peer_info;
689            anyhow::ensure!(
690                peer_info.id == req.metadata.peer_id,
691                "suggested peer ID does not belong to the sender"
692            );
693
694            anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
695            self.announced_peers.send(Arc::new(peer_info)).ok();
696
697            // NOTE: read constructor without advancing the body
698            constructor = std::convert::identity(body).get_u32_le();
699        }
700
701        Ok((constructor, body))
702    }
703
704    fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
705        self.storage.insert(DhtValueSource::Remote, &req.value)
706    }
707
708    fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
709        let nodes = self
710            .routing_table
711            .lock()
712            .unwrap()
713            .closest(&req.key, (req.k as usize).min(self.config.max_k));
714
715        NodeResponse { nodes }
716    }
717
718    fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
719        if let Some(value) = self.storage.get(&req.key) {
720            ValueResponseRaw::Found(value)
721        } else {
722            let nodes = self
723                .routing_table
724                .lock()
725                .unwrap()
726                .closest(&req.key, (req.k as usize).min(self.config.max_k));
727
728            ValueResponseRaw::NotFound(nodes)
729        }
730    }
731
732    fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
733        self.local_peer_info
734            .lock()
735            .unwrap()
736            .clone()
737            .map(|info| NodeInfoResponse { info })
738    }
739}
740
741fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
742    let distance = MAX_XOR_DISTANCE - distance;
743
744    let mut result = *from;
745
746    let byte_offset = distance / 8;
747    rng.fill_bytes(&mut result.0[byte_offset..]);
748
749    let bit_offset = distance % 8;
750    if bit_offset != 0 {
751        let mask = 0xff >> bit_offset;
752        result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
753    }
754
755    result
756}
757
758pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
759    for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
760        let left = u64::from_be_bytes(left.try_into().unwrap());
761        let right = u64::from_be_bytes(right.try_into().unwrap());
762        let diff = left ^ right;
763        if diff != 0 {
764            return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
765        }
766    }
767
768    0
769}
770
771const MAX_XOR_DISTANCE: usize = 256;
772
/// Errors returned by `DhtQueryBuilder::find_value` and
/// `DhtQueryBuilder::find_peer_value_raw`.
#[derive(Debug, thiserror::Error)]
pub enum FindValueError {
    /// A value was found but could not be used: either it was a merged value
    /// (reported as `UnknownConstructor`) or its data failed to decode.
    #[error("failed to deserialize value: {0}")]
    InvalidData(#[from] tl_proto::TlError),
    /// The lookup returned no value for the key.
    #[error("value not found")]
    NotFound,
}
780
#[cfg(test)]
mod tests {
    use super::*;

    /// `random_key_at_distance(from, d)` must yield a key within XOR distance `d`
    /// of `from`. Checked across byte boundaries and both edges (0 and 256).
    #[test]
    fn proper_random_keys() {
        let mut rng = rand::rng();
        for distance in [0, 1, 7, 8, 9, 20, 64, 255, 256] {
            let peer_id = rand::random();
            assert_eq!(xor_distance(&peer_id, &peer_id), 0);

            let random_id = random_key_at_distance(&peer_id, distance, &mut rng);
            let actual = xor_distance(&peer_id, &random_id);
            assert!(
                actual <= distance,
                "expected distance <= {distance}, got {actual} ({peer_id} -> {random_id})"
            );
        }
    }
}