1use std::sync::{Arc, Mutex};
2
3use anyhow::Result;
4use bytes::{Buf, Bytes};
5use rand::RngCore;
6use tl_proto::TlRead;
7use tokio::sync::{Notify, broadcast};
8use tycho_util::realloc_box_enum;
9use tycho_util::time::now_sec;
10
11pub use self::config::DhtConfig;
12pub use self::peer_resolver::{
13 PeerResolver, PeerResolverBuilder, PeerResolverConfig, PeerResolverHandle,
14};
15pub use self::query::DhtQueryMode;
16use self::query::{Query, QueryCache, StoreValue};
17use self::routing::HandlesRoutingTable;
18use self::storage::Storage;
19pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError};
20use crate::network::Network;
21use crate::proto::dht::{
22 NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef,
23 PeerValueRef, Value, ValueRef, ValueResponseRaw, rpc,
24};
25use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest};
26use crate::util::{NetworkExt, Routable};
27
28mod background_tasks;
29mod config;
30mod peer_resolver;
31mod query;
32mod routing;
33mod storage;
34
// Incoming-request metric names. Every query/message handled by `DhtService` bumps
// `METRIC_IN_REQ_TOTAL`; failures additionally bump `METRIC_IN_REQ_FAIL_TOTAL`.

/// Counter: all incoming DHT queries and messages.
const METRIC_IN_REQ_TOTAL: &str = "tycho_net_dht_in_req_total";
/// Counter: incoming DHT requests that failed (bad prefix, undecodable body, or handler error).
const METRIC_IN_REQ_FAIL_TOTAL: &str = "tycho_net_dht_in_req_fail_total";

/// Counter: incoming requests carrying a `WithPeerInfo` prefix.
const METRIC_IN_REQ_WITH_PEER_INFO_TOTAL: &str = "tycho_net_dht_in_req_with_peer_info_total";
/// Counter: incoming `FindNode` queries.
const METRIC_IN_REQ_FIND_NODE_TOTAL: &str = "tycho_net_dht_in_req_find_node_total";
/// Counter: incoming `FindValue` queries.
const METRIC_IN_REQ_FIND_VALUE_TOTAL: &str = "tycho_net_dht_in_req_find_value_total";
/// Counter: incoming `GetNodeInfo` queries.
const METRIC_IN_REQ_GET_NODE_INFO_TOTAL: &str = "tycho_net_dht_in_req_get_node_info_total";
/// Counter: incoming `Store` messages.
const METRIC_IN_REQ_STORE_TOTAL: &str = "tycho_net_dht_in_req_store_value_total";
44
45#[derive(Clone)]
46pub struct DhtClient {
47 inner: Arc<DhtInner>,
48 network: Network,
49}
50
51impl DhtClient {
52 #[inline]
53 pub fn network(&self) -> &Network {
54 &self.network
55 }
56
57 #[inline]
58 pub fn service(&self) -> &DhtService {
59 DhtService::wrap(&self.inner)
60 }
61
62 pub fn add_peer(&self, peer: Arc<PeerInfo>) -> Result<bool> {
63 anyhow::ensure!(peer.verify(now_sec()), "invalid peer info");
64 let added = self.inner.add_peer_info(&self.network, peer);
65 Ok(added)
66 }
67
68 pub async fn get_node_info(&self, peer_id: &PeerId) -> Result<PeerInfo> {
69 let res = self
70 .network
71 .query(peer_id, Request::from_tl(rpc::GetNodeInfo))
72 .await?;
73 let NodeInfoResponse { info } = res.parse_tl()?;
74 Ok(info)
75 }
76
77 pub fn entry(&self, name: PeerValueKeyName) -> DhtQueryBuilder<'_> {
78 DhtQueryBuilder {
79 inner: &self.inner,
80 network: &self.network,
81 name,
82 idx: 0,
83 }
84 }
85
86 pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option<Box<Value>> {
90 self.inner.find_value(&self.network, key_hash, mode).await
91 }
92}
93
94#[derive(Clone, Copy)]
95pub struct DhtQueryBuilder<'a> {
96 inner: &'a DhtInner,
97 network: &'a Network,
98 name: PeerValueKeyName,
99 idx: u32,
100}
101
102impl<'a> DhtQueryBuilder<'a> {
103 #[inline]
104 pub fn with_idx(&mut self, idx: u32) -> &mut Self {
105 self.idx = idx;
106 self
107 }
108
109 pub async fn find_value<T>(&self, peer_id: &PeerId) -> Result<T, FindValueError>
110 where
111 for<'tl> T: tl_proto::TlRead<'tl>,
112 {
113 let key_hash = tl_proto::hash(PeerValueKeyRef {
114 name: self.name,
115 peer_id,
116 });
117
118 match self
119 .inner
120 .find_value(self.network, &key_hash, DhtQueryMode::Closest)
121 .await
122 {
123 Some(value) => match value.as_ref() {
124 Value::Peer(value) => {
125 tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData)
126 }
127 Value::Merged(_) => Err(FindValueError::InvalidData(
128 tl_proto::TlError::UnknownConstructor,
129 )),
130 },
131 None => Err(FindValueError::NotFound),
132 }
133 }
134
135 pub async fn find_peer_value_raw(
136 &self,
137 peer_id: &PeerId,
138 ) -> Result<Box<PeerValue>, FindValueError> {
139 let key_hash = tl_proto::hash(PeerValueKeyRef {
140 name: self.name,
141 peer_id,
142 });
143
144 match self
145 .inner
146 .find_value(self.network, &key_hash, DhtQueryMode::Closest)
147 .await
148 {
149 Some(value) => {
150 realloc_box_enum!(value, {
151 Value::Peer(value) => Box::new(value) => Ok(value),
152 Value::Merged(_) => Err(FindValueError::InvalidData(
153 tl_proto::TlError::UnknownConstructor,
154 )),
155 })
156 }
157 None => Err(FindValueError::NotFound),
158 }
159 }
160
161 pub fn with_data<T>(&self, data: T) -> DhtQueryWithDataBuilder<'a>
162 where
163 T: tl_proto::TlWrite,
164 {
165 DhtQueryWithDataBuilder {
166 inner: *self,
167 data: tl_proto::serialize(&data),
168 at: None,
169 ttl: self.inner.config.max_stored_value_ttl.as_secs() as _,
170 with_peer_info: false,
171 }
172 }
173}
174
175pub struct DhtQueryWithDataBuilder<'a> {
176 inner: DhtQueryBuilder<'a>,
177 data: Vec<u8>,
178 at: Option<u32>,
179 ttl: u32,
180 with_peer_info: bool,
181}
182
183impl DhtQueryWithDataBuilder<'_> {
184 pub fn with_time(&mut self, at: u32) -> &mut Self {
185 self.at = Some(at);
186 self
187 }
188
189 pub fn with_ttl(&mut self, ttl: u32) -> &mut Self {
190 self.ttl = ttl;
191 self
192 }
193
194 pub fn with_peer_info(&mut self, with_peer_info: bool) -> &mut Self {
195 self.with_peer_info = with_peer_info;
196 self
197 }
198
199 pub async fn store(&self) -> Result<()> {
200 let dht = self.inner.inner;
201 let network = self.inner.network;
202
203 let mut value = self.make_unsigned_value_ref();
204 let signature = network.sign_tl(&value);
205 value.signature = &signature;
206
207 dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info)
208 .await
209 }
210
211 pub fn store_locally(&self) -> Result<bool, StorageError> {
212 let dht = self.inner.inner;
213 let network = self.inner.network;
214
215 let mut value = self.make_unsigned_value_ref();
216 let signature = network.sign_tl(&value);
217 value.signature = &signature;
218
219 dht.store_value_locally(&ValueRef::Peer(value))
220 }
221
222 pub fn into_signed_value(self) -> PeerValue {
223 let dht = self.inner.inner;
224 let network = self.inner.network;
225
226 let mut value = PeerValue {
227 key: PeerValueKey {
228 name: self.name,
229 peer_id: dht.local_id,
230 },
231 data: self.data.into_boxed_slice(),
232 expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
233 signature: Box::new([0; 64]),
234 };
235 *value.signature = network.sign_tl(&value);
236 value
237 }
238
239 fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> {
240 PeerValueRef {
241 key: PeerValueKeyRef {
242 name: self.inner.name,
243 peer_id: &self.inner.inner.local_id,
244 },
245 data: &self.data,
246 expires_at: self.at.unwrap_or_else(now_sec) + self.ttl,
247 signature: &[0; 64],
248 }
249 }
250}
251
252impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> {
253 type Target = DhtQueryBuilder<'a>;
254
255 #[inline]
256 fn deref(&self) -> &Self::Target {
257 &self.inner
258 }
259}
260
261impl std::ops::DerefMut for DhtQueryWithDataBuilder<'_> {
262 #[inline]
263 fn deref_mut(&mut self) -> &mut Self::Target {
264 &mut self.inner
265 }
266}
267
268pub struct DhtServiceBackgroundTasks {
269 inner: Arc<DhtInner>,
270}
271
272impl DhtServiceBackgroundTasks {
273 pub fn spawn(self, network: &Network) {
274 self.inner
275 .start_background_tasks(Network::downgrade(network));
276 }
277}
278
279pub struct DhtServiceBuilder {
280 local_id: PeerId,
281 config: Option<DhtConfig>,
282}
283
284impl DhtServiceBuilder {
285 pub fn with_config(mut self, config: DhtConfig) -> Self {
286 self.config = Some(config);
287 self
288 }
289
290 pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) {
291 let config = self.config.unwrap_or_default();
292
293 let storage = {
294 let mut builder = Storage::builder()
295 .with_max_capacity(config.max_storage_capacity)
296 .with_max_ttl(config.max_stored_value_ttl);
297
298 if let Some(time_to_idle) = config.storage_item_time_to_idle {
299 builder = builder.with_max_idle(time_to_idle);
300 }
301
302 builder.build()
303 };
304
305 let (announced_peers, _) = broadcast::channel(config.announced_peers_channel_capacity);
306
307 let inner = Arc::new(DhtInner {
308 local_id: self.local_id,
309 routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)),
310 storage,
311 local_peer_info: Mutex::new(None),
312 config,
313 announced_peers,
314 find_value_queries: Default::default(),
315 peer_added: Arc::new(Default::default()),
316 });
317
318 let background_tasks = DhtServiceBackgroundTasks {
319 inner: inner.clone(),
320 };
321
322 (background_tasks, DhtService(inner))
323 }
324}
325
326#[derive(Clone)]
327#[repr(transparent)]
328pub struct DhtService(Arc<DhtInner>);
329
330impl DhtService {
331 #[inline]
332 fn wrap(inner: &Arc<DhtInner>) -> &Self {
333 unsafe { &*(inner as *const Arc<DhtInner>).cast::<Self>() }
335 }
336
337 pub fn builder(local_id: PeerId) -> DhtServiceBuilder {
338 DhtServiceBuilder {
339 local_id,
340 config: None,
341 }
342 }
343
344 pub fn local_id(&self) -> &PeerId {
345 &self.0.local_id
346 }
347
348 pub fn make_client(&self, network: &Network) -> DhtClient {
349 DhtClient {
350 inner: self.0.clone(),
351 network: network.clone(),
352 }
353 }
354
355 pub fn make_peer_resolver(&self) -> PeerResolverBuilder {
356 PeerResolver::builder(self.clone())
357 }
358
359 pub fn has_peer(&self, peer_id: &PeerId) -> bool {
360 self.0.routing_table.lock().unwrap().contains(peer_id)
361 }
362
363 pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
364 self.0.store_value_locally(value)
365 }
366
367 pub fn insert_merger(
368 &self,
369 group_id: &[u8; 32],
370 merger: Arc<dyn DhtValueMerger>,
371 ) -> Option<Arc<dyn DhtValueMerger>> {
372 self.0.storage.insert_merger(group_id, merger)
373 }
374
375 pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
376 self.0.storage.remove_merger(group_id)
377 }
378
379 pub fn peer_added(&self) -> &Arc<Notify> {
380 &self.0.peer_added
381 }
382}
383
/// Network service entry point for the DHT.
///
/// Both handlers respond synchronously (ready futures). They first strip an optional
/// `WithPeerInfo` prefix via `try_handle_prefix`, then dispatch on the inner TL constructor.
impl Service<ServiceRequest> for DhtService {
    type QueryResponse = Response;
    type OnQueryFuture = futures_util::future::Ready<Option<Self::QueryResponse>>;
    type OnMessageFuture = futures_util::future::Ready<()>;

    /// Handles an incoming `FindNode`, `FindValue` or `GetNodeInfo` query.
    ///
    /// Returns `None` (and bumps the failure counter) when the prefix or body cannot be
    /// decoded, or when `GetNodeInfo` is asked before any local peer info exists.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_query",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize query: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(None);
            }
        };

        // `body` still starts with `constructor`; the macro dispatches on that tag.
        let response = crate::match_tl_request!(body, tag = constructor, {
            rpc::FindNode as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_node");
                metrics::counter!(METRIC_IN_REQ_FIND_NODE_TOTAL).increment(1);

                let res = self.0.handle_find_node(r);
                Some(tl_proto::serialize(res))
            },
            rpc::FindValue as ref r => {
                tracing::debug!(key = %PeerId::wrap(&r.key), k = r.k, "find_value");
                metrics::counter!(METRIC_IN_REQ_FIND_VALUE_TOTAL).increment(1);

                let res = self.0.handle_find_value(r);
                Some(tl_proto::serialize(res))
            },
            rpc::GetNodeInfo as _ => {
                tracing::debug!("get_node_info");
                metrics::counter!(METRIC_IN_REQ_GET_NODE_INFO_TOTAL).increment(1);

                // `None` until the local peer info has been created.
                self.0.handle_get_node_info().map(tl_proto::serialize)
            },
        }, e => {
            tracing::debug!("failed to deserialize query: {e}");
            None
        });

        // Covers both decode failures and a missing local peer info.
        if response.is_none() {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(response.map(|body| Response {
            version: Default::default(),
            body: Bytes::from(body),
        }))
    }

    /// Handles an incoming `Store` message by inserting the value as a remote value.
    ///
    /// Failures are only logged and counted; messages have no response.
    #[tracing::instrument(
        level = "debug",
        name = "on_dht_message",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_message(&self, req: ServiceRequest) -> Self::OnMessageFuture {
        metrics::counter!(METRIC_IN_REQ_TOTAL).increment(1);

        let (constructor, body) = match self.0.try_handle_prefix(&req) {
            Ok(rest) => rest,
            Err(e) => {
                tracing::debug!("failed to deserialize message: {e}");
                metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
                return futures_util::future::ready(());
            }
        };

        // Set by either the store-failure arm or the decode-failure arm; counted once below.
        let mut has_error = false;
        crate::match_tl_request!(body, tag = constructor, {
            rpc::StoreRef<'_> as ref r => {
                tracing::debug!("store");
                metrics::counter!(METRIC_IN_REQ_STORE_TOTAL).increment(1);

                if let Err(e) = self.0.handle_store(r) {
                    tracing::debug!("failed to store value: {e}");
                    has_error = true;
                }
            }
        }, e => {
            tracing::debug!("failed to deserialize message: {e}");
            has_error = true;
        });

        if has_error {
            metrics::counter!(METRIC_IN_REQ_FAIL_TOTAL).increment(1);
        }

        futures_util::future::ready(())
    }
}
484
485impl Routable for DhtService {
486 fn query_ids(&self) -> impl IntoIterator<Item = u32> {
487 [
488 rpc::WithPeerInfo::TL_ID,
489 rpc::FindNode::TL_ID,
490 rpc::FindValue::TL_ID,
491 rpc::GetNodeInfo::TL_ID,
492 ]
493 }
494
495 fn message_ids(&self) -> impl IntoIterator<Item = u32> {
496 [rpc::WithPeerInfo::TL_ID, rpc::Store::TL_ID]
497 }
498}
499
/// Shared state behind [`DhtService`], [`DhtClient`] and the background tasks.
struct DhtInner {
    /// ID of the local node; `add_peer_info` refuses to add it to the routing table.
    local_id: PeerId,
    routing_table: Mutex<HandlesRoutingTable>,
    /// Local key-value storage for DHT values.
    storage: Storage,
    /// Signed peer info of this node. `store_value` creates it lazily when `with_peer_info`
    /// is set; `handle_get_node_info` answers `None` while it is unset. (Other modules may
    /// also update it — not visible here.)
    local_peer_info: Mutex<Option<PeerInfo>>,
    config: DhtConfig,
    /// Broadcasts peer infos announced by remote peers through the `WithPeerInfo` prefix.
    announced_peers: broadcast::Sender<Arc<PeerInfo>>,
    /// Value lookups keyed by key hash. Presumably coalesces concurrent identical lookups —
    /// confirm in `query::QueryCache`.
    find_value_queries: QueryCache<Option<Box<Value>>>,
    /// Woken via `notify_waiters` whenever a new peer enters the routing table.
    peer_added: Arc<Notify>,
}

impl DhtInner {
    /// Looks up a value by `key_hash` starting from the peers in the routing table.
    async fn find_value(
        &self,
        network: &Network,
        key_hash: &[u8; 32],
        mode: DhtQueryMode,
    ) -> Option<Box<Value>> {
        self.find_value_queries
            .run(key_hash, || {
                // The routing-table guard is a temporary of this `let` statement, so the lock
                // is released before the returned future is polled.
                let query = Query::new(
                    network.clone(),
                    &self.routing_table.lock().unwrap(),
                    key_hash,
                    self.config.max_k,
                    mode,
                );

                Box::pin(query.find_value())
            })
            .await
    }

    /// Stores `value` locally, then sends it to peers selected from the routing table.
    ///
    /// When `with_peer_info` is set, the local peer info (created on first use and cached in
    /// `local_peer_info`) is attached to the store query.
    ///
    /// # Errors
    /// Fails only if the local storage rejects the value; the outcome of the remote store is
    /// not reported by this function.
    async fn store_value(
        &self,
        network: &Network,
        value: &ValueRef<'_>,
        with_peer_info: bool,
    ) -> Result<()> {
        self.storage.insert(DhtValueSource::Local, value)?;

        let local_peer_info = if with_peer_info {
            // Guard is dropped at the end of this block, before any `.await`.
            let mut node_info = self.local_peer_info.lock().unwrap();
            Some(
                node_info
                    .get_or_insert_with(|| self.make_local_peer_info(network, now_sec()))
                    .clone(),
            )
        } else {
            None
        };

        // As in `find_value`, the routing-table guard is a statement temporary and is not held
        // across the `.await` below.
        let query = StoreValue::new(
            network.clone(),
            &self.routing_table.lock().unwrap(),
            value,
            self.config.max_k,
            local_peer_info.as_ref(),
        );

        query.run().await;
        Ok(())
    }

    /// Inserts `value` into the local storage as a locally produced value.
    fn store_value_locally(&self, value: &ValueRef<'_>) -> Result<bool, StorageError> {
        self.storage.insert(DhtValueSource::Local, value)
    }

    /// Offers `peer_info` to the routing table; returns whether it was added.
    ///
    /// The local node is never added. On success, current `peer_added` waiters are woken.
    fn add_peer_info(&self, network: &Network, peer_info: Arc<PeerInfo>) -> bool {
        if peer_info.id == self.local_id {
            return false;
        }

        let mut routing_table = self.routing_table.lock().unwrap();
        // The closure also registers the peer in the network's known peers
        // (NOTE(review): meaning of the `false` flag is defined by `known_peers().insert`).
        let added = routing_table.add(
            peer_info.clone(),
            self.config.max_k,
            &self.config.max_peer_info_ttl,
            |peer_info| network.known_peers().insert(peer_info, false).ok(),
        );

        if added {
            // `notify_waiters` wakes only tasks already waiting; it stores no permit.
            self.peer_added.notify_waiters();
        }

        added
    }

    /// Builds a peer value owned by the local node with an all-zero signature placeholder;
    /// the caller is expected to sign it and fill in `signature`.
    fn make_unsigned_peer_value<'a>(
        &'a self,
        name: PeerValueKeyName,
        data: &'a [u8],
        expires_at: u32,
    ) -> PeerValueRef<'a> {
        PeerValueRef {
            key: PeerValueKeyRef {
                name,
                peer_id: &self.local_id,
            },
            data,
            expires_at,
            signature: &[0; 64],
        }
    }

    /// Creates signed peer info for the local node via `Network::sign_peer_info`, passing `now`
    /// and the configured `max_peer_info_ttl` in seconds.
    fn make_local_peer_info(&self, network: &Network, now: u32) -> PeerInfo {
        network.sign_peer_info(now, self.config.max_peer_info_ttl.as_secs() as u32)
    }

    /// Strips an optional `WithPeerInfo` prefix from the request body.
    ///
    /// Returns the constructor ID of the actual request and the body starting at that
    /// constructor (the constructor is peeked, not consumed).
    ///
    /// # Errors
    /// Fails if the body is shorter than a constructor ID, the prefix cannot be decoded, or
    /// the announced peer ID differs from the sender's.
    fn try_handle_prefix<'a>(&self, req: &'a ServiceRequest) -> Result<(u32, &'a [u8])> {
        let mut body = req.as_ref();
        anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);

        // `identity` hands `Buf::get_u32_le` a copy of the slice reference, so only the copy
        // advances: this peeks the constructor without consuming it from `body`.
        let mut constructor = std::convert::identity(body).get_u32_le();

        if constructor == rpc::WithPeerInfo::TL_ID {
            metrics::counter!(METRIC_IN_REQ_WITH_PEER_INFO_TOTAL).increment(1);

            // Advances `body` past the prefix.
            let peer_info = rpc::WithPeerInfo::read_from(&mut body)?.peer_info;
            // NOTE(review): only the ID is checked here, not the signature or expiry; presumably
            // subscribers of `announced_peers` verify it — confirm.
            anyhow::ensure!(
                peer_info.id == req.metadata.peer_id,
                "suggested peer ID does not belong to the sender"
            );

            anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof);
            // `send` errors only when there are no receivers; that is not a failure here.
            self.announced_peers.send(Arc::new(peer_info)).ok();

            constructor = std::convert::identity(body).get_u32_le();
        }

        Ok((constructor, body))
    }

    /// Inserts a value received from a remote peer into the local storage.
    fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result<bool, StorageError> {
        self.storage.insert(DhtValueSource::Remote, &req.value)
    }

    /// Returns the routing-table peers closest to `req.key`, at most `min(req.k, max_k)`.
    fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse {
        let nodes = self
            .routing_table
            .lock()
            .unwrap()
            .closest(&req.key, (req.k as usize).min(self.config.max_k));

        NodeResponse { nodes }
    }

    /// Returns the locally stored value for `req.key`, or otherwise the closest known peers
    /// (at most `min(req.k, max_k)`).
    fn handle_find_value(&self, req: &rpc::FindValue) -> ValueResponseRaw {
        if let Some(value) = self.storage.get(&req.key) {
            ValueResponseRaw::Found(value)
        } else {
            let nodes = self
                .routing_table
                .lock()
                .unwrap()
                .closest(&req.key, (req.k as usize).min(self.config.max_k));

            ValueResponseRaw::NotFound(nodes)
        }
    }

    /// Returns the local peer info, or `None` if it has not been created yet.
    fn handle_get_node_info(&self) -> Option<NodeInfoResponse> {
        self.local_peer_info
            .lock()
            .unwrap()
            .clone()
            .map(|info| NodeInfoResponse { info })
    }
}
674
675fn random_key_at_distance(from: &PeerId, distance: usize, rng: &mut impl RngCore) -> PeerId {
676 let distance = MAX_XOR_DISTANCE - distance;
677
678 let mut result = *from;
679
680 let byte_offset = distance / 8;
681 rng.fill_bytes(&mut result.0[byte_offset..]);
682
683 let bit_offset = distance % 8;
684 if bit_offset != 0 {
685 let mask = 0xff >> bit_offset;
686 result.0[byte_offset] ^= (result.0[byte_offset] ^ from.0[byte_offset]) & !mask;
687 }
688
689 result
690}
691
692pub fn xor_distance(left: &PeerId, right: &PeerId) -> usize {
693 for (i, (left, right)) in std::iter::zip(left.0.chunks(8), right.0.chunks(8)).enumerate() {
694 let left = u64::from_be_bytes(left.try_into().unwrap());
695 let right = u64::from_be_bytes(right.try_into().unwrap());
696 let diff = left ^ right;
697 if diff != 0 {
698 return MAX_XOR_DISTANCE - (i * 64 + diff.leading_zeros() as usize);
699 }
700 }
701
702 0
703}
704
/// Bit width of a 32-byte key, i.e. the largest value [`xor_distance`] can return.
const MAX_XOR_DISTANCE: usize = 32 * u8::BITS as usize;
706
/// Errors returned by the value lookups of [`DhtQueryBuilder`].
#[derive(Debug, thiserror::Error)]
pub enum FindValueError {
    /// A value was found but its payload does not decode as the requested type, or it is a
    /// merged value where a peer value was expected (reported as `UnknownConstructor`).
    #[error("failed to deserialize value: {0}")]
    InvalidData(#[from] tl_proto::TlError),
    /// The lookup finished without finding a value for the key.
    #[error("value not found")]
    NotFound,
}
714
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn proper_random_keys() {
        let peer_id = rand::random();
        let random_id = random_key_at_distance(&peer_id, 20, &mut rand::rng());
        println!("{peer_id}");
        println!("{random_id}");

        // Only the trailing 20 bits may differ, so the distance is bounded by 20 exactly
        // (the previous `<= 23` bound was looser than the construction guarantees).
        let distance = xor_distance(&peer_id, &random_id);
        println!("{distance}");
        assert!(distance <= 20);
    }

    #[test]
    fn random_keys_respect_distance_bound() {
        let mut rng = rand::rng();
        for _ in 0..16 {
            let peer_id: PeerId = rand::random();

            // Distance 0 must reproduce the original key exactly.
            assert!(random_key_at_distance(&peer_id, 0, &mut rng) == peer_id);

            // Cover byte boundaries (multiples of 8), word boundaries (multiples of 64)
            // and the maximum distance.
            for distance in [1, 7, 8, 9, 63, 64, 65, 128, 255, 256] {
                let key = random_key_at_distance(&peer_id, distance, &mut rng);
                assert!(xor_distance(&peer_id, &key) <= distance);
            }
        }
    }
}