1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{Notify, broadcast};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13 PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22 NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef,
23 PeerValueRef, Value, ValueRef, ValueResponseRaw, rpc,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
/// Number of incoming DHT requests (both queries and messages).
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Number of incoming DHT requests that could not be parsed or handled.
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Number of incoming requests that carried a `WithPeerInfo` prefix.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Number of incoming `FindNode` queries.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Number of incoming `FindValue` queries.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Number of incoming `GetNodeInfo` queries.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Number of incoming `Store` messages.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47 inner: Arc<DhtInner>,
48 network: Network,
49}
50
51impl DhtClient {
52 #[inline]
53 pub fn network(&self) -> &Network {
54 &self.network
55 }
56
57 #[inline]
58 pub fn service(&self) -> &DhtService {
59 DhtService::wrap(&self.inner)
60 }
61
62 pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63 anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64 let added = self.inner.add_peer_info(&self.network, peer);
65 Ok(added)
66 }
67
68 pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
69 let res = self
70 .network
71 .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
72 .await?;
73 let NodeInfoResponse { info } = res.parse_tl()?;
74 Ok(info)
75 }
76
77 pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
78 DhtQueryBuilder {
79 inner: &self.inner,
80 network: &self.network,
81 name,
82 idx: 0,
83 }
84 }
85
86 pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
90 self.inner.find_value(&self.network, key_hash, mode).await
91 }
92}
93
94#[derive(Clone, Copy)]
95pub struct DhtQueryBuilder<'a> {
96 inner: &'a DhtInner,
97 network: &'a Network,
98 name: PeerValueKeyName,
99 idx: u32,
100}
101
102impl<'a> DhtQueryBuilder<'a> {
103 #[inline]
104 pub fn with_idx(&mut self, idx: u32) -> &mut Self {
105 self.idx = idx;
106 self
107 }
108
109 pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
110 where
111 for<'tl> T: tl_proto::TlRead<'tl>,
112 {
113 let key_hash = tl_proto::hash(PeerValueKeyRef {
114 name: self.name,
115 peer_id,
116 });
117
118 match self
119 .inner
120 .find_value(self.network, &key_hash, DhtQueryMode::Closest)
121 .await
122 {
123 Some(value) => match value.as_ref() {
124 Value::Peer(value) => {
125 tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
126 }
127 Value::Merged(_) => Err(FindValueError::InvalidData(
128 tl_proto::TlError::UnknownConstructor,
129 )),
130 },
131 None => Err(FindValueError::NotFound),
132 }
133 }
134
135 pub async fn find_peer_value_raw(
136 &self,
137 peer_id: &PeerId,
138 ) -> Result<Box<PeerValue>, FindValueError> {
139 let key_hash = tl_proto::hash(PeerValueKeyRef {
140 name: self.name,
141 peer_id,
142 });
143
144 match self
145 .inner
146 .find_value(self.network, &key_hash, DhtQueryMode::Closest)
147 .await
148 {
149 Some(value) => {
150 realloc_box_enum!(value, {
151 Value::Peer(value) => Box::new(value) => Ok(value),
152 Value::Merged(_) => Err(FindValueError::InvalidData(
153 tl_proto::TlError::UnknownConstructor,
154 )),
155 })
156 }
157 None => Err(FindValueError::NotFound),
158 }
159 }
160
161 pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
162 where
163 T: tl_proto::TlWrite,
164 {
165 DhtQueryWithDataBuilder {
166 inner: *self,
167 data: tl_proto::serialize(&data),
168 at: None,
169 ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
170 with_peer_info: false,
171 }
172 }
173}
174
175pub struct DhtQueryWithDataBuilder<'a> {
176 inner: DhtQueryBuilder<'a>,
177 data: Vec<u8>,
178 at: Option<u32>,
179 ttl: u32,
180 with_peer_info: bool,
181}
182
183impl DhtQueryWithDataBuilder<'_> {
184 pub fn with_time(&mut self, at: u32) -> &mut Self {
185 self.at = Some(at);
186 self
187 }
188
189 pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
190 self.ttl = ttl;
191 self
192 }
193
194 pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
195 self.with_peer_info = with_peer_info;
196 self
197 }
198
199 pub async fn store(&self) -> Result<()> {
200 let dht = self.inner.inner;
201 let network = self.inner.network;
202
203 let mut value = self.make_unsigned_value_ref();
204 let signature = network.sign_tl(&value);
205 value.signature = &signature;
206
207 dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
208 .await
209 }
210
211 pub fn store_locally(&self) -> Result<bool, StorageError> {
212 let dht = self.inner.inner;
213 let network = self.inner.network;
214
215 let mut value = self.make_unsigned_value_ref();
216 let signature = network.sign_tl(&value);
217 value.signature = &signature;
218
219 dht.store_value_locally(&ValueRef::Peer(value))
220 }
221
222 pub fn into_signed_value(self) -> PeerValue {
223 let dht = self.inner.inner;
224 let network = self.inner.network;
225
226 let mut value = PeerValue {
227 key: PeerValueKey {
228 name: self.name,
229 peer_id: dht.local_id,
230 },
231 data: self.data.into_boxed_slice(),
232 expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
233 signature: Box::new([0; 64]),
234 };
235 *value.signature = network.sign_tl(&value);
236 value
237 }
238
239 fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
240 PeerValueRef {
241 key: PeerValueKeyRef {
242 name: self.inner.name,
243 peer_id: &self.inner.inner.local_id,
244 },
245 data: &self.data,
246 expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
247 signature: &[0; 64],
248 }
249 }
250}
251
252impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
253 type Target = DhtQueryBuilder<'a>;
254
255 #[inline]
256 fn deref(&self) -> &Self::Target {
257 &self.inner
258 }
259}
260
261impl std::ops::DerefMut for DhtQueryWithDataBuilder<'_> {
262 #[inline]
263 fn deref_mut(&mut self) -> &mut Self::Target {
264 &mut self.inner
265 }
266}
267
268pub struct DhtServiceBackgroundTasks {
269 inner: Arc<DhtInner>,
270}
271
272impl DhtServiceBackgroundTasks {
273 pub fn spawn(self, network: &Network) {
274 self.inner
275 .start_background_tasks(Network::downgrade(network));
276 }
277}
278
279pub struct DhtServiceBuilder {
280 local_id: PeerId,
281 config: Option<DhtConfig>,
282}
283
284impl DhtServiceBuilder {
285 pub fn with_config(mut self, config: DhtConfig) -> Self {
286 self.config = Some(config);
287 self
288 }
289
290 pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
291 let config = self.config.unwrap_or_default();
292
293 let storage = {
294 let mut builder = Storage::builder()
295 .with_max_capacity(config.max_storage_capacity)
296 .with_max_ttl(config.max_stored_value_ttl);
297
298 if let Some(time_to_idle) = config.storage_item_time_to_idle {
299 builder = builder.with_max_idle(time_to_idle);
300 }
301
302 builder.build()
303 };
304
305 let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
306
307 let inner = Arc::new(DhtInner {
308 local_id: self.local_id,
309 routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
310 storage,
311 local_peer_info: Mutex::new(None),
312 config,
313 announced_peers,
314 find_value_queries: Default::default(),
315 peer_added: Arc::new(Default::default()),
316 });
317
318 let background_tasks = DhtServiceBackgroundTasks {
319 inner: inner.clone(),
320 };
321
322 (background_tasks, DhtService(inner))
323 }
324}
325
326#[derive(Clone)]
327#[repr(transparent)]
328pub struct DhtService(Arc<DhtInner>);
329
330impl DhtService {
331 #[inline]
332 fn wrap(inner: &Arc<DhtInner>) -> &Self {
333 unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
335 }
336
337 pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
338 DhtServiceBuilder {
339 local_id,
340 config: None,
341 }
342 }
343
344 pub fn local_id(&self) -> &PeerId {
345 &self.0.local_id
346 }
347
348 pub fn make_client(&self, network: &Network) -> DhtClient {
349 DhtClient {
350 inner: self.0.clone(),
351 network: network.clone(),
352 }
353 }
354
355 pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
356 PeerResolver::builder(self.clone())
357 }
358
359 pub fn has_peer(&self, peer_id: &PeerId) -> bool {
360 self.0.routing_table.lock().unwrap().contains(peer_id)
361 }
362
363 pub fn find_local_closest(&self, key: &[u8; 32], count: usize) -> Vec<Arc<PeerInfo>> {
364 self.0.routing_table.lock().unwrap().closest(key, count)
365 }
366
367 pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
368 self.0.store_value_locally(value)
369 }
370
371 pub fn insert_merger(
372 &self,
373 group_id: &[u8; 32],
374 merger: Arc<dyn DhtValueMerger>,
375 ) -> Option<Arc<dyn DhtValueMerger>> {
376 self.0.storage.insert_merger(group_id, merger)
377 }
378
379 pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
380 self.0.storage.remove_merger(group_id)
381 }
382
383 pub fn peer_added(&self) -> &Arc<Notify> {
384 &self.0.peer_added
385 }
386}
387
/// Network-facing request handling for the DHT.
///
/// Both handlers run synchronously and return already-completed futures.
impl Service<ServiceRequest> for DhtService {
    type QueryResponse = Response;
    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
    type OnMessageFuture = futures_util::future::Ready<()>;

    /// Handles an incoming DHT query (`FindNode`, `FindValue` or `GetNodeInfo`),
    /// optionally prefixed with `WithPeerInfo`.
    ///
    /// Returns `None`, and increments the failure metric, when the prefix or
    /// the query cannot be decoded, or when `GetNodeInfo` is requested before
    /// any local peer info exists.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_query",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        // Strip an optional `WithPeerInfo` prefix and peek at the actual tag.
        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize query: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(None);
            }
        };

        let response = crate::match_tl_request!(body, tag = constructor, {
            rpc::FindNode as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);

                let res = self.0.handle_find_node(r);
                Some(tl_proto::serialize(res))
            },
            rpc::FindValue as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);

                let res = self.0.handle_find_value(r);
                Some(tl_proto::serialize(res))
            },
            rpc::GetNodeInfo as _ => {
                tracing::debug!("get_node_info");
                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);

                self.0.handle_get_node_info().map(tl_proto::serialize)
            },
        }, e => {
            tracing::debug!("failed to deserialize query: {e}");
            None
        });

        // Covers both decode errors and a missing local peer info.
        if response.is_none() {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(response.map(|body| Response {
            version: Default::default(),
            body: Bytes::from(body),
        }))
    }

    /// Handles an incoming DHT message (`Store`), optionally prefixed with
    /// `WithPeerInfo`.
    ///
    /// Messages have no response. Decode errors and storage errors are only
    /// logged and counted; an `Ok(false)` from the storage is not counted as
    /// a failure.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_message",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize message: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(());
            }
        };

        // Set by either branch below so the failure metric is bumped once.
        let mut has_error = false;
        crate::match_tl_request!(body, tag = constructor, {
            rpc::StoreRef<'_> as ref r => {
                tracing::debug!("store");
                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);

                if let Err(e) = self.0.handle_store(r) {
                    tracing::debug!("failed to store value: {e}");
                    has_error = true;
                }
            }
        }, e => {
            tracing::debug!("failed to deserialize message: {e}");
            has_error = true;
        });

        if has_error {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(())
    }
}
488
489impl Routable for DhtService {
490 fn query_ids(&self) -> impl IntoIterator<Item = u32> {
491 [
492 rpc::WithPeerInfo::TL_ID,
493 rpc::FindNode::TL_ID,
494 rpc::FindValue::TL_ID,
495 rpc::GetNodeInfo::TL_ID,
496 ]
497 }
498
499 fn message_ids(&self) -> impl IntoIterator<Item = u32> {
500 [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
501 }
502}
503
504struct DhtInner {
505 local_id: PeerId,
506 routing_table: Mutex<HandlesRoutingTable>,
507 storage: Storage,
508 local_peer_info: Mutex<Option<PeerInfo>>,
509 config: DhtConfig,
510 announced_peers: broadcast::Sender<Arc<PeerInfo>>,
511 find_value_queries: QueryCache<Option<Box<Value>>>,
512 peer_added: Arc<Notify>,
513}
514
515impl DhtInner {
516 async fn find_value(
517 &self,
518 network: &Network,
519 key_hash: &[u8; 32],
520 mode: DhtQueryMode,
521 ) -> Option<Box<Value>> {
522 self.find_value_queries
523 .run(key_hash, || {
524 let query = Query::new(
525 network.clone(),
526 &self.routing_table.lock().unwrap(),
527 key_hash,
528 self.config.max_k,
529 mode,
530 );
531
532 Box::pin(query.find_value())
534 })
535 .await
536 }
537
538 async fn store_value(
539 &self,
540 network: &Network,
541 value: &ValueRef<'_>,
542 with_peer_info: bool,
543 ) -> Result<()> {
544 self.storage.insert(DhtValueSource::Local, value)?;
545
546 let local_peer_info = if with_peer_info {
547 let mut node_info = self.local_peer_info.lock().unwrap();
548 Some(
549 node_info
550 .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
551 .clone(),
552 )
553 } else {
554 None
555 };
556
557 let query = StoreValue::new(
558 network.clone(),
559 &self.routing_table.lock().unwrap(),
560 value,
561 self.config.max_k,
562 local_peer_info.as_ref(),
563 );
564
565 query.run().await;
567 Ok(())
568 }
569
570 fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
571 self.storage.insert(DhtValueSource::Local, value)
572 }
573
574 fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
576 if peer_info.id == self.local_id {
577 return false;
578 }
579
580 let mut routing_table = self.routing_table.lock().unwrap();
581 let added = routing_table.add(
582 peer_info.clone(),
583 self.config.max_k,
584 &self.config.max_peer_info_ttl,
585 |peer_info| network.known_peers().insert(peer_info, false).ok(),
586 );
587
588 if added {
589 self.peer_added.notify_waiters();
590 }
591
592 added
593 }
594
595 fn make_unsigned_peer_value<'a>(
596 &'a self,
597 name: PeerValueKeyName,
598 data: &'a [u8],
599 expires_at: u32,
600 ) -> PeerValueRef<'a> {
601 PeerValueRef {
602 key: PeerValueKeyRef {
603 name,
604 peer_id: &self.local_id,
605 },
606 data,
607 expires_at,
608 signature: &[0; 64],
609 }
610 }
611
612 fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
613 network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
614 }
615
616 fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
617 let mut body = req.as_ref();
618 anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
619
620 let mut constructor = std::convert::identity(body).get_u32_le();
622
623 if constructor == rpc::WithPeerInfo::TL_ID {
624 metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);
625
626 let peer_info = rpc::WithPeerInfo::read_from(&mut body)?.peer_info;
627 anyhow::ensure!(
628 peer_info.id == req.metadata.peer_id,
629 "suggested peer ID does not belong to the sender"
630 );
631
632 anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
633 self.announced_peers.send(Arc::new(peer_info)).ok();
634
635 constructor = std::convert::identity(body).get_u32_le();
637 }
638
639 Ok((constructor, body))
640 }
641
642 fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
643 self.storage.insert(DhtValueSource::Remote, &req.value)
644 }
645
646 fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
647 let nodes = self
648 .routing_table
649 .lock()
650 .unwrap()
651 .closest(&req.key, (req.k as usize).min(self.config.max_k));
652
653 NodeResponse { nodes }
654 }
655
656 fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
657 if let Some(value) = self.storage.get(&req.key) {
658 ValueResponseRaw::Found(value)
659 } else {
660 let nodes = self
661 .routing_table
662 .lock()
663 .unwrap()
664 .closest(&req.key, (req.k as usize).min(self.config.max_k));
665
666 ValueResponseRaw::NotFound(nodes)
667 }
668 }
669
670 fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
671 self.local_peer_info
672 .lock()
673 .unwrap()
674 .clone()
675 .map(|info| NodeInfoResponse { info })
676 }
677}
678
679fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
680 let distance = MAX_XOR_DISTANCE - distance;
681
682 let mut result = *from;
683
684 let byte_offset = distance / 8;
685 rng.fill_bytes(&mut result.0[byte_offset..]);
686
687 let bit_offset = distance % 8;
688 if bit_offset != 0 {
689 let mask = 0xff >> bit_offset;
690 result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
691 }
692
693 result
694}
695
696pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
697 for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
698 let left = u64::from_be_bytes(left.try_into().unwrap());
699 let right = u64::from_be_bytes(right.try_into().unwrap());
700 let diff = left ^ right;
701 if diff != 0 {
702 return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
703 }
704 }
705
706 0
707}
708
/// Largest possible XOR distance: the bit width of a 32-byte peer id.
const MAX_XOR_DISTANCE: usize = 32 * 8;
710
/// Error returned by typed DHT lookups such as [`DhtQueryBuilder::find_value`]
/// and [`DhtQueryBuilder::find_peer_value_raw`].
#[derive(Debug, thiserror::Error)]
pub enum FindValueError {
    /// A value was found but could not be decoded as requested. A merged value
    /// found in place of a peer value is also reported here, as
    /// `TlError::UnknownConstructor`.
    #[error("failed to deserialize value: {0}")]
    InvalidData(#[from] tl_proto::TlError),
    /// The lookup returned no value for the key.
    #[error("value not found")]
    NotFound,
}
718
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn proper_random_keys() {
        let peer_id = rand::random();
        let random_id = random_key_at_distance(&peer_id, 20, &mut rand::rng());
        println!("{peer_id}");
        println!("{random_id}");

        let distance = xor_distance(&peer_id, &random_id);
        println!("{distance}");
        // Only the lowest 20 bits may differ, so the distance can't exceed 20.
        assert!(distance <= 20);
    }

    #[test]
    fn random_keys_stay_within_distance() {
        let mut rng = rand::rng();
        let peer_id: PeerId = rand::random();

        // Distance 0 leaves no bits to randomize.
        let same = random_key_at_distance(&peer_id, 0, &mut rng);
        assert_eq!(xor_distance(&peer_id, &same), 0);

        // Cover byte and 64-bit word boundaries as well as the full range.
        for distance in [1, 7, 8, 9, 63, 64, 65, 200, MAX_XOR_DISTANCE] {
            let key = random_key_at_distance(&peer_id, distance, &mut rng);
            assert!(xor_distance(&peer_id, &key) <= distance);
        }
    }

    #[test]
    fn xor_distance_counts_significant_bits() {
        let zero = *PeerId::wrap(&[0u8; 32]);
        assert_eq!(xor_distance(&zero, &zero), 0);

        let mut lowest = [0u8; 32];
        lowest[31] = 0x01;
        assert_eq!(xor_distance(&zero, PeerId::wrap(&lowest)), 1);

        let mut highest = [0u8; 32];
        highest[0] = 0x80;
        assert_eq!(
            xor_distance(&zero, PeerId::wrap(&highest)),
            MAX_XOR_DISTANCE
        );
    }
}
734}