1use std::collections::hash_map;
2use std::sync::Arc;
3use std::time::Duration;
4
5use ahash::{HashMapExt, HashSetExt};
6use anyhow::Result;
7use bytes::Bytes;
8use futures_util::stream::FuturesUnordered;
9use futures_util::{Future, StreamExt};
10use tokio::sync::Semaphore;
11use tycho_util::futures::{JoinTask, Shared, WeakSharedHandle};
12use tycho_util::sync::{rayon_run, yield_on_complex};
13use tycho_util::time::now_sec;
14use tycho_util::{FastDashMap, FastHashMap, FastHashSet};
15
16use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable};
17use crate::network::Network;
18use crate::proto::dht::{NodeResponse, Value, ValueRef, ValueResponse, rpc};
19use crate::types::{PeerId, PeerInfo, Request};
20use crate::util::NetworkExt;
21
/// Deduplicates concurrent queries that share the same 32-byte key.
///
/// Only weak handles to the in-flight shared futures are stored, so a cache entry
/// can become stale once every strong instance is gone; see [`QueryCache::run`].
pub struct QueryCache<R> {
    // Weak handles to in-flight shared query tasks, keyed by target id.
    cache: FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
}
25
impl<R> QueryCache<R> {
    /// Runs the query produced by `f` for `target_id`, sharing it with any
    /// concurrent caller that asks for the same key.
    ///
    /// If the cache already holds a handle for `target_id` that can still be
    /// upgraded, that shared future is awaited and `f` is never called.
    /// Otherwise `f()` is wrapped in `Shared::new(JoinTask::new(..))`, a weak
    /// handle to it is stored under `target_id`, and the new future is awaited.
    /// Every waiter receives the output, which is why `R: Clone` is required.
    ///
    /// The cache entry is cleaned up via `remove_if` + `on_drop`, which only
    /// removes an entry whose strong count is zero. This happens either on normal
    /// completion (when `is_last` is set) or when this call is dropped before
    /// completing (through `Guard::drop`).
    pub async fn run<F, Fut>(&self, target_id: &[u8; 32], f: F) -> R
    where
        R: Clone + Send + 'static,
        F: FnOnce() -> Fut,
        Fut: Future<Output = R> + Send + 'static,
    {
        use dashmap::mapref::entry::Entry;

        // NOTE(review): the dashmap entry (and its shard lock) is alive while
        // `f()` runs below, so `f` must not touch this cache synchronously.
        let fut = match self.cache.entry(*target_id) {
            Entry::Vacant(entry) => {
                let fut = Shared::new(JoinTask::new(f()));
                // `downgrade` may yield `None`; in that case nothing is cached and
                // the future is simply awaited by this caller alone.
                if let Some(weak) = fut.downgrade() {
                    entry.insert(weak);
                }
                fut
            }
            Entry::Occupied(mut entry) => {
                if let Some(fut) = entry.get().upgrade() {
                    // Join the query that is already in flight for this key.
                    fut
                } else {
                    // Stale entry: start a new query and replace (or drop) the
                    // cached weak handle.
                    let fut = Shared::new(JoinTask::new(f()));
                    match fut.downgrade() {
                        Some(weak) => entry.insert(weak),
                        None => entry.remove(),
                    };
                    fut
                }
            }
        };

        // Predicate for `remove_if`: only drop the entry once no strong
        // instance of the shared future remains.
        fn on_drop<R>(_key: &[u8; 32], value: &WeakSpawnedFut<R>) -> bool {
            value.strong_count() == 0
        }

        let (output, is_last) = {
            // Cancellation guard: if `run` is dropped while awaiting, `Drop`
            // consumes the held future and removes the cache entry when that
            // consumption reports it was the last instance.
            struct Guard<'a, R> {
                target_id: &'a [u8; 32],
                cache: &'a FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
                fut: Option<Shared<JoinTask<R>>>,
            }

            impl<R> Drop for Guard<'_, R> {
                fn drop(&mut self) {
                    // `consume` presumably returns whether this was the last
                    // instance; a missing future counts as `false`.
                    if self.fut.take().map(Shared::consume).unwrap_or_default() {
                        self.cache.remove_if(self.target_id, on_drop);
                    }
                }
            }

            let mut guard = Guard {
                target_id,
                cache: &self.cache,
                fut: None,
            };
            let fut = guard.fut.insert(fut);

            // Resolves to the shared output plus a flag telling whether this
            // awaiter was the last one.
            fut.await
        };

        if is_last {
            // Completed as the last instance: clean up the cache entry.
            self.cache.remove_if(target_id, on_drop);
        }

        output
    }
}
101
102impl<R> Default for QueryCache<R> {
103 fn default() -> Self {
104 Self {
105 cache: Default::default(),
106 }
107 }
108}
109
/// Weak handle to a shared, joinable query task, as stored in [`QueryCache`].
type WeakSpawnedFut<T> = WeakSharedHandle<JoinTask<T>>;
111
/// Selects which routing-table nodes seed the candidate set of a [`Query`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DhtQueryMode {
    /// Seed with the nodes closest to the query target (the default).
    #[default]
    Closest,
    /// Seed with the nodes closest to a freshly generated random id.
    Random,
}
118
/// An iterative DHT lookup (for a value or for peers) towards a fixed target id.
pub struct Query {
    network: Network,
    // Candidate peers; the table's `local_id` is the query target, so
    // `visit_closest(local_id, ..)` walks the peers closest to the target.
    candidates: SimpleRoutingTable,
    // `k` sent in each request, and the limit used when visiting/adding candidates.
    max_k: usize,
}
124
125impl Query {
126 pub fn new(
127 network: Network,
128 routing_table: &HandlesRoutingTable,
129 target_id: &[u8; 32],
130 max_k: usize,
131 mode: DhtQueryMode,
132 ) -> Self {
133 let mut candidates = SimpleRoutingTable::new(PeerId(*target_id));
134
135 let random_id;
136 let target_id_for_full = match mode {
137 DhtQueryMode::Closest => target_id,
138 DhtQueryMode::Random => {
139 random_id = rand::random();
140 &random_id
141 }
142 };
143
144 routing_table.visit_closest(target_id_for_full, max_k, |node| {
145 candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some);
146 });
147
148 Self {
149 network,
150 candidates,
151 max_k,
152 }
153 }
154
155 fn local_id(&self) -> &[u8; 32] {
156 self.candidates.local_id.as_bytes()
157 }
158
159 #[tracing::instrument(skip_all)]
160 pub async fn find_value(mut self) -> Option<Box<Value>> {
161 let request_body = Bytes::from(tl_proto::serialize(rpc::FindValue {
163 key: *self.local_id(),
164 k: self.max_k as u32,
165 }));
166
167 let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
169 let mut futures = FuturesUnordered::new();
170 self.candidates
171 .visit_closest(self.local_id(), self.max_k, |node| {
172 futures.push(Self::visit::<ValueResponse>(
173 self.network.clone(),
174 node.clone(),
175 request_body.clone(),
176 &semaphore,
177 ));
178 });
179
180 let mut visited = FastHashSet::new();
182 while let Some((node, res)) = futures.next().await {
183 match res {
184 Some(Ok(ValueResponse::Found(value))) => {
186 let mut signature_checked = false;
187 let is_valid =
188 value.verify_ext(now_sec(), self.local_id(), &mut signature_checked);
189 tracing::debug!(peer_id = %node.id, is_valid, "found value");
190
191 yield_on_complex(signature_checked).await;
192
193 if !is_valid {
194 continue;
196 }
197
198 return Some(value);
199 }
200 Some(Ok(ValueResponse::NotFound(nodes))) => {
202 let node_count = nodes.len();
203 let has_new = self
204 .update_candidates(now_sec(), self.max_k, nodes, &mut visited)
205 .await;
206 tracing::debug!(peer_id = %node.id, count = node_count, has_new, "received nodes");
207
208 if !has_new {
209 continue;
211 }
212
213 self.candidates
215 .visit_closest(self.local_id(), self.max_k, |node| {
216 if visited.contains(&node.id) {
217 return;
219 }
220 futures.push(Self::visit::<ValueResponse>(
221 self.network.clone(),
222 node.clone(),
223 request_body.clone(),
224 &semaphore,
225 ));
226 });
227 }
228 Some(Err(e)) => {
230 tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
231 }
232 None => {
234 tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
235 }
236 }
237 }
238
239 None
241 }
242
243 #[tracing::instrument(skip_all)]
244 pub async fn find_peers(mut self, depth: Option<usize>) -> FastHashMap<PeerId, Arc<PeerInfo>> {
245 let request_body = Bytes::from(tl_proto::serialize(rpc::FindNode {
247 key: *self.local_id(),
248 k: self.max_k as u32,
249 }));
250
251 let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
253 let mut futures = FuturesUnordered::new();
254 self.candidates
255 .visit_closest(self.local_id(), self.max_k, |node| {
256 futures.push(Self::visit::<NodeResponse>(
257 self.network.clone(),
258 node.clone(),
259 request_body.clone(),
260 &semaphore,
261 ));
262 });
263
264 let mut current_depth = 0;
266 let max_depth = depth.unwrap_or(usize::MAX);
267 let mut result = FastHashMap::<PeerId, Arc<PeerInfo>>::new();
268 while let Some((node, res)) = futures.next().await {
269 match res {
270 Some(Ok(NodeResponse { nodes })) => {
272 tracing::debug!(peer_id = %node.id, count = nodes.len(), "received nodes");
273 if !self
274 .update_candidates_full(now_sec(), self.max_k, nodes, &mut result)
275 .await
276 {
277 continue;
279 }
280
281 current_depth += 1;
282 if current_depth >= max_depth {
283 break;
285 }
286
287 self.candidates
289 .visit_closest(self.local_id(), self.max_k, |node| {
290 if result.contains_key(&node.id) {
291 return;
293 }
294 futures.push(Self::visit::<NodeResponse>(
295 self.network.clone(),
296 node.clone(),
297 request_body.clone(),
298 &semaphore,
299 ));
300 });
301 }
302 Some(Err(e)) => {
304 tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
305 }
306 None => {
308 tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
309 }
310 }
311 }
312
313 result
315 }
316
317 async fn update_candidates(
318 &mut self,
319 now: u32,
320 max_k: usize,
321 nodes: Vec<Arc<PeerInfo>>,
322 visited: &mut FastHashSet<PeerId>,
323 ) -> bool {
324 let mut has_new = false;
325 process_only_valid(now, nodes, |node| {
326 if visited.insert(node.id) {
328 self.candidates.add(node, max_k, &Duration::MAX, Some);
329 has_new = true;
330 }
331 })
332 .await;
333
334 has_new
335 }
336
337 async fn update_candidates_full(
338 &mut self,
339 now: u32,
340 max_k: usize,
341 nodes: Vec<Arc<PeerInfo>>,
342 visited: &mut FastHashMap<PeerId, Arc<PeerInfo>>,
343 ) -> bool {
344 let mut has_new = false;
345 process_only_valid(now, nodes, |node| {
346 match visited.entry(node.id) {
347 hash_map::Entry::Vacant(entry) => {
349 let node = entry.insert(node).clone();
350 self.candidates.add(node, max_k, &Duration::MAX, Some);
351 has_new = true;
352 }
353 hash_map::Entry::Occupied(mut entry) => {
355 if entry.get().created_at < node.created_at {
356 *entry.get_mut() = node;
357 }
358 }
359 }
360 })
361 .await;
362
363 has_new
364 }
365
366 async fn visit<T>(
367 network: Network,
368 node: Arc<PeerInfo>,
369 request_body: Bytes,
370 semaphore: &Semaphore,
371 ) -> (Arc<PeerInfo>, Option<Result<T>>)
372 where
373 for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
374 {
375 let Ok(_permit) = semaphore.acquire().await else {
376 return (node, None);
377 };
378
379 let req = network.query(&node.id, Request {
380 version: Default::default(),
381 body: request_body.clone(),
382 });
383
384 let res = match tokio::time::timeout(REQUEST_TIMEOUT, req).await {
385 Ok(res) => {
386 Some(res.and_then(|res| tl_proto::deserialize::<T>(&res.body).map_err(Into::into)))
387 }
388 Err(_) => None,
389 };
390
391 (node, res)
392 }
393}
394
/// A prepared fan-out of DHT store requests, driven to completion by `run`.
///
/// The `F = ()` default lets `impl StoreValue<()>` host the constructor. That
/// constructor returns a `StoreValue` over the concrete per-node request future type.
pub struct StoreValue<F = ()> {
    // One pending store request per selected node.
    futures: FuturesUnordered<F>,
}
398
399impl StoreValue<()> {
400 pub fn new(
401 network: Network,
402 routing_table: &HandlesRoutingTable,
403 value: &ValueRef<'_>,
404 max_k: usize,
405 local_peer_info: Option<&PeerInfo>,
406 ) -> StoreValue<impl Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send + use<>> {
407 let key_hash = match value {
408 ValueRef::Peer(value) => tl_proto::hash(&value.key),
409 ValueRef::Merged(value) => tl_proto::hash(&value.key),
410 };
411
412 let request_body = Bytes::from(match local_peer_info {
413 Some(peer_info) => tl_proto::serialize((
414 rpc::WithPeerInfo::wrap(peer_info),
415 rpc::StoreRef::wrap(value),
416 )),
417 None => tl_proto::serialize(rpc::StoreRef::wrap(value)),
418 });
419
420 let semaphore = Arc::new(Semaphore::new(10));
421 let futures = futures_util::stream::FuturesUnordered::new();
422 routing_table.visit_closest(&key_hash, max_k, |node| {
423 futures.push(Self::visit(
424 network.clone(),
425 node.load_peer_info(),
426 request_body.clone(),
427 semaphore.clone(),
428 ));
429 });
430
431 StoreValue { futures }
432 }
433
434 async fn visit(
435 network: Network,
436 node: Arc<PeerInfo>,
437 request_body: Bytes,
438 semaphore: Arc<Semaphore>,
439 ) -> (Arc<PeerInfo>, Option<Result<()>>) {
440 let Ok(_permit) = semaphore.acquire().await else {
441 return (node, None);
442 };
443
444 let req = network.send(&node.id, Request {
445 version: Default::default(),
446 body: request_body.clone(),
447 });
448
449 let res = (tokio::time::timeout(REQUEST_TIMEOUT, req).await).ok();
450
451 (node, res)
452 }
453}
454
455impl<T: Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send> StoreValue<T> {
456 #[tracing::instrument(level = "debug", skip_all, name = "store_value")]
457 pub async fn run(mut self) {
458 while let Some((node, res)) = self.futures.next().await {
459 match res {
460 Some(Ok(())) => {
461 tracing::debug!(peer_id = %node.id, "value stored");
462 }
463 Some(Err(e)) => {
464 tracing::warn!(peer_id = %node.id, "failed to store value: {e}");
465 }
466 None => {
468 tracing::warn!(peer_id = %node.id, "failed to store value: timeout");
469 }
470 }
471 }
472 }
473}
474
475async fn process_only_valid<F>(now: u32, mut nodes: Vec<Arc<PeerInfo>>, mut handle_valid_node: F)
476where
477 F: FnMut(Arc<PeerInfo>) + Send,
478{
479 const SPAWN_THRESHOLD: usize = 4;
480
481 if nodes.len() > SPAWN_THRESHOLD {
483 let nodes = rayon_run(move || {
484 nodes.retain(|node| node.verify(now));
485 nodes
486 })
487 .await;
488
489 for node in nodes {
490 handle_valid_node(node);
491 }
492 } else {
493 for node in nodes {
494 let mut signature_checked = false;
495 let is_valid = node.verify_ext(now, &mut signature_checked);
496 yield_on_complex(signature_checked).await;
497
498 if is_valid {
499 handle_valid_node(node);
500 }
501 }
502 };
503}
504
/// Timeout applied to each individual request sent by the `visit` helpers.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(500);
/// Maximum number of concurrently in-flight requests per lookup.
const MAX_PARALLEL_REQUESTS: usize = 10;