1use std::collections::hash_map;
2use std::sync::Arc;
3use std::time::Duration;
4
5use ahash::{HashMapExt, HashSetExt};
6use anyhow::Result;
7use bytes::Bytes;
8use futures_util::stream::FuturesUnordered;
9use futures_util::{Future, StreamExt};
10use tokio::sync::Semaphore;
11use tycho_util::futures::{JoinTask, Shared, WeakSharedHandle};
12use tycho_util::sync::{rayon_run, yield_on_complex};
13use tycho_util::time::now_sec;
14use tycho_util::{FastDashMap, FastHashMap, FastHashSet};
15
16use crate::dht::config::DhtConfig;
17use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable};
18use crate::network::Network;
19use crate::proto::dht::{NodeResponse, Value, ValueRef, ValueResponse, rpc};
20use crate::types::{PeerId, PeerInfo, Request};
21use crate::util::NetworkExt;
22
/// Deduplicates concurrent queries that share the same 32-byte key.
///
/// Holds weak handles to in-flight query futures so that concurrent callers
/// with the same key can await one shared result (see [`QueryCache::run`]).
pub struct QueryCache<R> {
    /// Key -> weak handle to the shared future currently computing the result.
    cache: FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
}
26
impl<R> QueryCache<R> {
    /// Runs the query produced by `f` for `target_id`, sharing it with concurrent callers.
    ///
    /// If the cache holds a weak handle for `target_id` that can still be upgraded, this
    /// call awaits that same shared future and `f` is not called. Otherwise `f()` is
    /// called, its future is wrapped into `Shared::new(JoinTask::new(..))`, and a weak
    /// handle to it is stored under `target_id`.
    ///
    /// The cache entry is removed (only if no strong handles remain, see `on_drop`)
    /// either when the awaited future reports this caller as the last one, or when this
    /// call is dropped before completion and `Shared::consume` reports the same.
    pub async fn run<F, Fut>(&self, target_id: &[u8; 32], f: F) -> R
    where
        R: Clone + Send + 'static,
        F: FnOnce() -> Fut,
        Fut: Future<Output = R> + Send + 'static,
    {
        use dashmap::mapref::entry::Entry;

        let fut = match self.cache.entry(*target_id) {
            Entry::Vacant(entry) => {
                // Nothing cached yet: start a new query and remember it weakly.
                let fut = Shared::new(JoinTask::new(f()));
                if let Some(weak) = fut.downgrade() {
                    entry.insert(weak);
                }
                fut
            }
            Entry::Occupied(mut entry) => {
                if let Some(fut) = entry.get().upgrade() {
                    // A query for this key is still alive: join it.
                    fut
                } else {
                    // The cached handle is stale: start a new query and replace
                    // (or drop) the entry.
                    let fut = Shared::new(JoinTask::new(f()));
                    match fut.downgrade() {
                        Some(weak) => entry.insert(weak),
                        None => entry.remove(),
                    };
                    fut
                }
            }
        };

        // Removal predicate for `remove_if`: only evict the entry when nobody holds a
        // strong handle to the shared future anymore.
        fn on_drop<R>(_key: &[u8; 32], value: &WeakSpawnedFut<R>) -> bool {
            value.strong_count() == 0
        }

        let (output, is_last) = {
            // Owns the shared future while it is awaited, so that cache cleanup also
            // happens if this `run` future is dropped (cancelled) mid-await.
            struct Guard<'a, R> {
                target_id: &'a [u8; 32],
                cache: &'a FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
                fut: Option<Shared<JoinTask<R>>>,
            }

            impl<R> Drop for Guard<'_, R> {
                fn drop(&mut self) {
                    // NOTE(review): `Shared::consume` is presumably `true` when this was the
                    // last instance of a not-yet-completed future; confirm in tycho_util.
                    if self.fut.take().map(Shared::consume).unwrap_or_default() {
                        self.cache.remove_if(self.target_id, on_drop);
                    }
                }
            }

            let mut guard = Guard {
                target_id,
                cache: &self.cache,
                fut: None,
            };
            let fut = guard.fut.insert(fut);

            // The shared future yields `(output, is_last)`; `is_last` presumably marks the
            // final consumer of the shared result — TODO confirm against `Shared`.
            fut.await
        };

        if is_last {
            // The last consumer finished: evict the entry unless someone re-populated it.
            self.cache.remove_if(target_id, on_drop);
        }

        output
    }
}
102
103impl<R> Default for QueryCache<R> {
104 fn default() -> Self {
105 Self {
106 cache: Default::default(),
107 }
108 }
109}
110
/// Weak handle to a shared [`JoinTask`] future, as stored in [`QueryCache`].
type WeakSpawnedFut<T> = WeakSharedHandle<JoinTask<T>>;
112
/// Strategy for choosing the initial candidates of a [`Query`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DhtQueryMode {
    /// Seed candidates with the routing-table nodes closest to the query target.
    #[default]
    Closest,
    /// Seed candidates with the routing-table nodes closest to a randomly generated id.
    Random,
}
119
/// An iterative DHT lookup towards a single target id.
pub struct Query {
    /// Network used to send requests to candidate peers.
    network: Network,
    /// Candidate peers; the table's local id is the query target.
    candidates: SimpleRoutingTable,
    /// Number of closest nodes to visit/request (taken from `DhtConfig::max_k`).
    max_k: usize,
    /// Per-request timeout (taken from `DhtConfig::request_timeout`).
    timeout: Duration,
}
126
127impl Query {
128 pub fn new(
129 network: Network,
130 routing_table: &HandlesRoutingTable,
131 target_id: &[u8; 32],
132 config: &DhtConfig,
133 mode: DhtQueryMode,
134 ) -> Self {
135 let mut candidates = SimpleRoutingTable::new(PeerId(*target_id));
136
137 let random_id;
138 let target_id_for_full = match mode {
139 DhtQueryMode::Closest => target_id,
140 DhtQueryMode::Random => {
141 random_id = rand::random();
142 &random_id
143 }
144 };
145
146 let max_k = config.max_k;
147 let timeout = config.request_timeout;
148
149 routing_table.visit_closest(target_id_for_full, max_k, |node| {
150 candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some);
151 });
152
153 Self {
154 network,
155 candidates,
156 max_k,
157 timeout,
158 }
159 }
160
161 fn local_id(&self) -> &[u8; 32] {
162 self.candidates.local_id.as_bytes()
163 }
164
165 #[tracing::instrument(skip_all)]
166 pub async fn find_value(mut self) -> Option<Box<Value>> {
167 let request_body = Bytes::from(tl_proto::serialize(rpc::FindValue {
169 key: *self.local_id(),
170 k: self.max_k as u32,
171 }));
172
173 let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
175 let mut futures = FuturesUnordered::new();
176 self.candidates
177 .visit_closest(self.local_id(), self.max_k, |node| {
178 futures.push(Self::visit::<ValueResponse>(
179 self.network.clone(),
180 node.clone(),
181 request_body.clone(),
182 &semaphore,
183 self.timeout,
184 ));
185 });
186
187 let mut visited = FastHashSet::new();
189 while let Some((node, res)) = futures.next().await {
190 match res {
191 Some(Ok(ValueResponse::Found(value))) => {
193 let mut signature_checked = false;
194 let is_valid =
195 value.verify_ext(now_sec(), self.local_id(), &mut signature_checked);
196 tracing::debug!(peer_id = %node.id, is_valid, "found value");
197
198 yield_on_complex(signature_checked).await;
199
200 if !is_valid {
201 continue;
203 }
204
205 return Some(value);
206 }
207 Some(Ok(ValueResponse::NotFound(nodes))) => {
209 let node_count = nodes.len();
210 let has_new = self
211 .update_candidates(now_sec(), self.max_k, nodes, &mut visited)
212 .await;
213 tracing::debug!(peer_id = %node.id, count = node_count, has_new, "received nodes");
214
215 if !has_new {
216 continue;
218 }
219
220 self.candidates
222 .visit_closest(self.local_id(), self.max_k, |node| {
223 if visited.contains(&node.id) {
224 return;
226 }
227 futures.push(Self::visit::<ValueResponse>(
228 self.network.clone(),
229 node.clone(),
230 request_body.clone(),
231 &semaphore,
232 self.timeout,
233 ));
234 });
235 }
236 Some(Err(e)) => {
238 tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
239 }
240 None => {
242 tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
243 }
244 }
245 }
246
247 None
249 }
250
251 #[tracing::instrument(skip_all)]
252 pub async fn find_peers(mut self, depth: Option<usize>) -> FastHashMap<PeerId, Arc<PeerInfo>> {
253 let request_body = Bytes::from(tl_proto::serialize(rpc::FindNode {
255 key: *self.local_id(),
256 k: self.max_k as u32,
257 }));
258
259 let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
261 let mut futures = FuturesUnordered::new();
262 self.candidates
263 .visit_closest(self.local_id(), self.max_k, |node| {
264 futures.push(Self::visit::<NodeResponse>(
265 self.network.clone(),
266 node.clone(),
267 request_body.clone(),
268 &semaphore,
269 self.timeout,
270 ));
271 });
272
273 let mut current_depth = 0;
275 let max_depth = depth.unwrap_or(usize::MAX);
276 let mut result = FastHashMap::<PeerId, Arc<PeerInfo>>::new();
277 while let Some((node, res)) = futures.next().await {
278 match res {
279 Some(Ok(NodeResponse { nodes })) => {
281 tracing::debug!(peer_id = %node.id, count = nodes.len(), "received nodes");
282 if !self
283 .update_candidates_full(now_sec(), self.max_k, nodes, &mut result)
284 .await
285 {
286 continue;
288 }
289
290 current_depth += 1;
291 if current_depth >= max_depth {
292 break;
294 }
295
296 self.candidates
298 .visit_closest(self.local_id(), self.max_k, |node| {
299 if result.contains_key(&node.id) {
300 return;
302 }
303 futures.push(Self::visit::<NodeResponse>(
304 self.network.clone(),
305 node.clone(),
306 request_body.clone(),
307 &semaphore,
308 self.timeout,
309 ));
310 });
311 }
312 Some(Err(e)) => {
314 tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
315 }
316 None => {
318 tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
319 }
320 }
321 }
322
323 result
325 }
326
327 async fn update_candidates(
328 &mut self,
329 now: u32,
330 max_k: usize,
331 nodes: Vec<Arc<PeerInfo>>,
332 visited: &mut FastHashSet<PeerId>,
333 ) -> bool {
334 let mut has_new = false;
335 process_only_valid(now, nodes, |node| {
336 if visited.insert(node.id) {
338 self.candidates.add(node, max_k, &Duration::MAX, Some);
339 has_new = true;
340 }
341 })
342 .await;
343
344 has_new
345 }
346
347 async fn update_candidates_full(
348 &mut self,
349 now: u32,
350 max_k: usize,
351 nodes: Vec<Arc<PeerInfo>>,
352 visited: &mut FastHashMap<PeerId, Arc<PeerInfo>>,
353 ) -> bool {
354 let mut has_new = false;
355 process_only_valid(now, nodes, |node| {
356 match visited.entry(node.id) {
357 hash_map::Entry::Vacant(entry) => {
359 let node = entry.insert(node).clone();
360 self.candidates.add(node, max_k, &Duration::MAX, Some);
361 has_new = true;
362 }
363 hash_map::Entry::Occupied(mut entry) => {
365 if entry.get().created_at < node.created_at {
366 *entry.get_mut() = node;
367 }
368 }
369 }
370 })
371 .await;
372
373 has_new
374 }
375
376 async fn visit<T>(
377 network: Network,
378 node: Arc<PeerInfo>,
379 request_body: Bytes,
380 semaphore: &Semaphore,
381 timeout: Duration,
382 ) -> (Arc<PeerInfo>, Option<Result<T>>)
383 where
384 for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
385 {
386 let Ok(_permit) = semaphore.acquire().await else {
387 return (node, None);
388 };
389
390 let req = network.query(&node.id, Request {
391 version: Default::default(),
392 body: request_body.clone(),
393 });
394
395 let res = match tokio::time::timeout(timeout, req).await {
396 Ok(res) => {
397 Some(res.and_then(|res| tl_proto::deserialize::<T>(&res.body).map_err(Into::into)))
398 }
399 Err(_) => None,
400 };
401
402 (node, res)
403 }
404}
405
/// A batch of pending `Store` requests, created by [`StoreValue::new`] and driven
/// to completion by [`StoreValue::run`].
pub struct StoreValue<F = ()> {
    /// One future per target peer, each yielding the peer and its request outcome.
    futures: FuturesUnordered<F>,
}
409
410impl StoreValue<()> {
411 pub fn new(
412 network: Network,
413 routing_table: &HandlesRoutingTable,
414 value: &ValueRef<'_>,
415 config: &DhtConfig,
416 local_peer_info: Option<&PeerInfo>,
417 ) -> StoreValue<impl Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send + use<>> {
418 let key_hash = match value {
419 ValueRef::Peer(value) => tl_proto::hash(&value.key),
420 ValueRef::Merged(value) => tl_proto::hash(&value.key),
421 };
422
423 let request_body = Bytes::from(match local_peer_info {
424 Some(peer_info) => tl_proto::serialize((
425 rpc::WithPeerInfo::wrap(peer_info),
426 rpc::StoreRef::wrap(value),
427 )),
428 None => tl_proto::serialize(rpc::StoreRef::wrap(value)),
429 });
430
431 let semaphore = Arc::new(Semaphore::new(10));
432 let futures = futures_util::stream::FuturesUnordered::new();
433 routing_table.visit_closest(&key_hash, config.max_k, |node| {
434 futures.push(Self::visit(
435 network.clone(),
436 node.load_peer_info(),
437 request_body.clone(),
438 semaphore.clone(),
439 config.request_timeout,
440 ));
441 });
442
443 StoreValue { futures }
444 }
445
446 async fn visit(
447 network: Network,
448 node: Arc<PeerInfo>,
449 request_body: Bytes,
450 semaphore: Arc<Semaphore>,
451 timeout: Duration,
452 ) -> (Arc<PeerInfo>, Option<Result<()>>) {
453 let Ok(_permit) = semaphore.acquire().await else {
454 return (node, None);
455 };
456
457 let req = network.send(&node.id, Request {
458 version: Default::default(),
459 body: request_body.clone(),
460 });
461
462 let res = (tokio::time::timeout(timeout, req).await).ok();
463
464 (node, res)
465 }
466}
467
468impl<T: Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send> StoreValue<T> {
469 #[tracing::instrument(level = "debug", skip_all, name = "store_value")]
470 pub async fn run(mut self) {
471 while let Some((node, res)) = self.futures.next().await {
472 match res {
473 Some(Ok(())) => {
474 tracing::debug!(peer_id = %node.id, "value stored");
475 }
476 Some(Err(e)) => {
477 tracing::warn!(peer_id = %node.id, "failed to store value: {e}");
478 }
479 None => {
481 tracing::warn!(peer_id = %node.id, "failed to store value: timeout");
482 }
483 }
484 }
485 }
486}
487
488async fn process_only_valid<F>(now: u32, mut nodes: Vec<Arc<PeerInfo>>, mut handle_valid_node: F)
489where
490 F: FnMut(Arc<PeerInfo>) + Send,
491{
492 const SPAWN_THRESHOLD: usize = 4;
493
494 if nodes.len() > SPAWN_THRESHOLD {
496 let nodes = rayon_run(move || {
497 nodes.retain(|node| node.verify(now));
498 nodes
499 })
500 .await;
501
502 for node in nodes {
503 handle_valid_node(node);
504 }
505 } else {
506 for node in nodes {
507 let mut signature_checked = false;
508 let is_valid = node.verify_ext(now, &mut signature_checked);
509 yield_on_complex(signature_checked).await;
510
511 if is_valid {
512 handle_valid_node(node);
513 }
514 }
515 };
516}
517
/// Maximum number of requests a single DHT operation runs concurrently
/// (the number of permits of its per-operation semaphore).
const MAX_PARALLEL_REQUESTS: usize = 10;