tycho_network/dht/
query.rs

1use std::collections::hash_map;
2use std::sync::Arc;
3use std::time::Duration;
4
5use ahash::{HashMapExt, HashSetExt};
6use anyhow::Result;
7use bytes::Bytes;
8use futures_util::stream::FuturesUnordered;
9use futures_util::{Future, StreamExt};
10use tokio::sync::Semaphore;
11use tycho_util::futures::{JoinTask, Shared, WeakSharedHandle};
12use tycho_util::sync::{rayon_run, yield_on_complex};
13use tycho_util::time::now_sec;
14use tycho_util::{FastDashMap, FastHashMap, FastHashSet};
15
16use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable};
17use crate::network::Network;
18use crate::proto::dht::{NodeResponse, Value, ValueRef, ValueResponse, rpc};
19use crate::types::{PeerId, PeerInfo, Request};
20use crate::util::NetworkExt;
21
22pub struct QueryCache<R> {
23    cache: FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
24}
25
26impl<R> QueryCache<R> {
27    pub async fn run<F, Fut>(&self, target_id: &[u8; 32], f: F) -> R
28    where
29        R: Clone + Send + 'static,
30        F: FnOnce() -> Fut,
31        Fut: Future<Output = R> + Send + 'static,
32    {
33        use dashmap::mapref::entry::Entry;
34
35        let fut = match self.cache.entry(*target_id) {
36            Entry::Vacant(entry) => {
37                let fut = Shared::new(JoinTask::new(f()));
38                if let Some(weak) = fut.downgrade() {
39                    entry.insert(weak);
40                }
41                fut
42            }
43            Entry::Occupied(mut entry) => {
44                if let Some(fut) = entry.get().upgrade() {
45                    fut
46                } else {
47                    let fut = Shared::new(JoinTask::new(f()));
48                    match fut.downgrade() {
49                        Some(weak) => entry.insert(weak),
50                        None => entry.remove(),
51                    };
52                    fut
53                }
54            }
55        };
56
57        fn on_drop<R>(_key: &[u8; 32], value: &WeakSpawnedFut<R>) -> bool {
58            value.strong_count() == 0
59        }
60
61        let (output, is_last) = {
62            struct Guard<'a, R> {
63                target_id: &'a [u8; 32],
64                cache: &'a FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
65                fut: Option<Shared<JoinTask<R>>>,
66            }
67
68            impl<R> Drop for Guard<'_, R> {
69                fn drop(&mut self) {
70                    // Remove value from cache if we consumed the last future instance
71                    if self.fut.take().map(Shared::consume).unwrap_or_default() {
72                        self.cache.remove_if(self.target_id, on_drop);
73                    }
74                }
75            }
76
77            // Wrap future into guard to remove it from cache event it was cancelled
78            let mut guard = Guard {
79                target_id,
80                cache: &self.cache,
81                fut: None,
82            };
83            let fut = guard.fut.insert(fut);
84
85            // Await future.
86            // If `Shared` future is not polled to `Complete` state,
87            // the guard will try to consume it and remove from cache
88            // if it was the last instance.
89            fut.await
90        };
91
92        // TODO: add ttl and force others to make a request for a fresh data
93        if is_last {
94            // Remove value from cache if we consumed the last future instance
95            self.cache.remove_if(target_id, on_drop);
96        }
97
98        output
99    }
100}
101
102impl<R> Default for QueryCache<R> {
103    fn default() -> Self {
104        Self {
105            cache: Default::default(),
106        }
107    }
108}
109
110type WeakSpawnedFut<T> = WeakSharedHandle<JoinTask<T>>;
111
/// How `Query::new` picks the initial candidates from the routing table.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DhtQueryMode {
    /// Start from the known nodes closest to the target id.
    #[default]
    Closest,
    /// Start from the known nodes closest to a freshly generated random id.
    Random,
}
118
119pub struct Query {
120    network: Network,
121    candidates: SimpleRoutingTable,
122    max_k: usize,
123}
124
125impl Query {
126    pub fn new(
127        network: Network,
128        routing_table: &HandlesRoutingTable,
129        target_id: &[u8; 32],
130        max_k: usize,
131        mode: DhtQueryMode,
132    ) -> Self {
133        let mut candidates = SimpleRoutingTable::new(PeerId(*target_id));
134
135        let random_id;
136        let target_id_for_full = match mode {
137            DhtQueryMode::Closest => target_id,
138            DhtQueryMode::Random => {
139                random_id = rand::random();
140                &random_id
141            }
142        };
143
144        routing_table.visit_closest(target_id_for_full, max_k, |node| {
145            candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some);
146        });
147
148        Self {
149            network,
150            candidates,
151            max_k,
152        }
153    }
154
155    fn local_id(&self) -> &[u8; 32] {
156        self.candidates.local_id.as_bytes()
157    }
158
159    #[tracing::instrument(skip_all)]
160    pub async fn find_value(mut self) -> Option<Box<Value>> {
161        // Prepare shared request
162        let request_body = Bytes::from(tl_proto::serialize(rpc::FindValue {
163            key: *self.local_id(),
164            k: self.max_k as u32,
165        }));
166
167        // Prepare request to initial candidates
168        let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
169        let mut futures = FuturesUnordered::new();
170        self.candidates
171            .visit_closest(self.local_id(), self.max_k, |node| {
172                futures.push(Self::visit::<ValueResponse>(
173                    self.network.clone(),
174                    node.clone(),
175                    request_body.clone(),
176                    &semaphore,
177                ));
178            });
179
180        // Process responses and refill futures until the value is found or all peers are traversed
181        let mut visited = FastHashSet::new();
182        while let Some((node, res)) = futures.next().await {
183            match res {
184                // Return the value if found
185                Some(Ok(ValueResponse::Found(value))) => {
186                    let mut signature_checked = false;
187                    let is_valid =
188                        value.verify_ext(now_sec(), self.local_id(), &mut signature_checked);
189                    tracing::debug!(peer_id = %node.id, is_valid, "found value");
190
191                    yield_on_complex(signature_checked).await;
192
193                    if !is_valid {
194                        // Ignore invalid values
195                        continue;
196                    }
197
198                    return Some(value);
199                }
200                // Refill futures from the nodes response
201                Some(Ok(ValueResponse::NotFound(nodes))) => {
202                    let node_count = nodes.len();
203                    let has_new = self
204                        .update_candidates(now_sec(), self.max_k, nodes, &mut visited)
205                        .await;
206                    tracing::debug!(peer_id = %node.id, count = node_count, has_new, "received nodes");
207
208                    if !has_new {
209                        // Do nothing if candidates were not changed
210                        continue;
211                    }
212
213                    // Add new nodes from the closest range
214                    self.candidates
215                        .visit_closest(self.local_id(), self.max_k, |node| {
216                            if visited.contains(&node.id) {
217                                // Skip already visited nodes
218                                return;
219                            }
220                            futures.push(Self::visit::<ValueResponse>(
221                                self.network.clone(),
222                                node.clone(),
223                                request_body.clone(),
224                                &semaphore,
225                            ));
226                        });
227                }
228                // Do nothing on error
229                Some(Err(e)) => {
230                    tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
231                }
232                // Do nothing on timeout
233                None => {
234                    tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
235                }
236            }
237        }
238
239        // Done
240        None
241    }
242
243    #[tracing::instrument(skip_all)]
244    pub async fn find_peers(mut self, depth: Option<usize>) -> FastHashMap<PeerId, Arc<PeerInfo>> {
245        // Prepare shared request
246        let request_body = Bytes::from(tl_proto::serialize(rpc::FindNode {
247            key: *self.local_id(),
248            k: self.max_k as u32,
249        }));
250
251        // Prepare request to initial candidates
252        let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
253        let mut futures = FuturesUnordered::new();
254        self.candidates
255            .visit_closest(self.local_id(), self.max_k, |node| {
256                futures.push(Self::visit::<NodeResponse>(
257                    self.network.clone(),
258                    node.clone(),
259                    request_body.clone(),
260                    &semaphore,
261                ));
262            });
263
264        // Process responses and refill futures until all peers are traversed
265        let mut current_depth = 0;
266        let max_depth = depth.unwrap_or(usize::MAX);
267        let mut result = FastHashMap::<PeerId, Arc<PeerInfo>>::new();
268        while let Some((node, res)) = futures.next().await {
269            match res {
270                // Refill futures from the nodes response
271                Some(Ok(NodeResponse { nodes })) => {
272                    tracing::debug!(peer_id = %node.id, count = nodes.len(), "received nodes");
273                    if !self
274                        .update_candidates_full(now_sec(), self.max_k, nodes, &mut result)
275                        .await
276                    {
277                        // Do nothing if candidates were not changed
278                        continue;
279                    }
280
281                    current_depth += 1;
282                    if current_depth >= max_depth {
283                        // Stop on max depth
284                        break;
285                    }
286
287                    // Add new nodes from the closest range
288                    self.candidates
289                        .visit_closest(self.local_id(), self.max_k, |node| {
290                            if result.contains_key(&node.id) {
291                                // Skip already visited nodes
292                                return;
293                            }
294                            futures.push(Self::visit::<NodeResponse>(
295                                self.network.clone(),
296                                node.clone(),
297                                request_body.clone(),
298                                &semaphore,
299                            ));
300                        });
301                }
302                // Do nothing on error
303                Some(Err(e)) => {
304                    tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
305                }
306                // Do nothing on timeout
307                None => {
308                    tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
309                }
310            }
311        }
312
313        // Done
314        result
315    }
316
317    async fn update_candidates(
318        &mut self,
319        now: u32,
320        max_k: usize,
321        nodes: Vec<Arc<PeerInfo>>,
322        visited: &mut FastHashSet<PeerId>,
323    ) -> bool {
324        let mut has_new = false;
325        process_only_valid(now, nodes, |node| {
326            // Insert a new entry
327            if visited.insert(node.id) {
328                self.candidates.add(node, max_k, &Duration::MAX, Some);
329                has_new = true;
330            }
331        })
332        .await;
333
334        has_new
335    }
336
337    async fn update_candidates_full(
338        &mut self,
339        now: u32,
340        max_k: usize,
341        nodes: Vec<Arc<PeerInfo>>,
342        visited: &mut FastHashMap<PeerId, Arc<PeerInfo>>,
343    ) -> bool {
344        let mut has_new = false;
345        process_only_valid(now, nodes, |node| {
346            match visited.entry(node.id) {
347                // Insert a new entry
348                hash_map::Entry::Vacant(entry) => {
349                    let node = entry.insert(node).clone();
350                    self.candidates.add(node, max_k, &Duration::MAX, Some);
351                    has_new = true;
352                }
353                // Try to replace an old entry
354                hash_map::Entry::Occupied(mut entry) => {
355                    if entry.get().created_at < node.created_at {
356                        *entry.get_mut() = node;
357                    }
358                }
359            }
360        })
361        .await;
362
363        has_new
364    }
365
366    async fn visit<T>(
367        network: Network,
368        node: Arc<PeerInfo>,
369        request_body: Bytes,
370        semaphore: &Semaphore,
371    ) -> (Arc<PeerInfo>, Option<Result<T>>)
372    where
373        for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
374    {
375        let Ok(_permit) = semaphore.acquire().await else {
376            return (node, None);
377        };
378
379        let req = network.query(&node.id, Request {
380            version: Default::default(),
381            body: request_body.clone(),
382        });
383
384        let res = match tokio::time::timeout(REQUEST_TIMEOUT, req).await {
385            Ok(res) => {
386                Some(res.and_then(|res| tl_proto::deserialize::<T>(&res.body).map_err(Into::into)))
387            }
388            Err(_) => None,
389        };
390
391        (node, res)
392    }
393}
394
395pub struct StoreValue<F = ()> {
396    futures: FuturesUnordered<F>,
397}
398
399impl StoreValue<()> {
400    pub fn new(
401        network: Network,
402        routing_table: &HandlesRoutingTable,
403        value: &ValueRef<'_>,
404        max_k: usize,
405        local_peer_info: Option<&PeerInfo>,
406    ) -> StoreValue<impl Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send + use<>> {
407        let key_hash = match value {
408            ValueRef::Peer(value) => tl_proto::hash(&value.key),
409            ValueRef::Merged(value) => tl_proto::hash(&value.key),
410        };
411
412        let request_body = Bytes::from(match local_peer_info {
413            Some(peer_info) => tl_proto::serialize((
414                rpc::WithPeerInfo::wrap(peer_info),
415                rpc::StoreRef::wrap(value),
416            )),
417            None => tl_proto::serialize(rpc::StoreRef::wrap(value)),
418        });
419
420        let semaphore = Arc::new(Semaphore::new(10));
421        let futures = futures_util::stream::FuturesUnordered::new();
422        routing_table.visit_closest(&key_hash, max_k, |node| {
423            futures.push(Self::visit(
424                network.clone(),
425                node.load_peer_info(),
426                request_body.clone(),
427                semaphore.clone(),
428            ));
429        });
430
431        StoreValue { futures }
432    }
433
434    async fn visit(
435        network: Network,
436        node: Arc<PeerInfo>,
437        request_body: Bytes,
438        semaphore: Arc<Semaphore>,
439    ) -> (Arc<PeerInfo>, Option<Result<()>>) {
440        let Ok(_permit) = semaphore.acquire().await else {
441            return (node, None);
442        };
443
444        let req = network.send(&node.id, Request {
445            version: Default::default(),
446            body: request_body.clone(),
447        });
448
449        let res = (tokio::time::timeout(REQUEST_TIMEOUT, req).await).ok();
450
451        (node, res)
452    }
453}
454
455impl<T: Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send> StoreValue<T> {
456    #[tracing::instrument(level = "debug", skip_all, name = "store_value")]
457    pub async fn run(mut self) {
458        while let Some((node, res)) = self.futures.next().await {
459            match res {
460                Some(Ok(())) => {
461                    tracing::debug!(peer_id = %node.id, "value stored");
462                }
463                Some(Err(e)) => {
464                    tracing::warn!(peer_id = %node.id, "failed to store value: {e}");
465                }
466                // Do nothing on timeout
467                None => {
468                    tracing::warn!(peer_id = %node.id, "failed to store value: timeout");
469                }
470            }
471        }
472    }
473}
474
475async fn process_only_valid<F>(now: u32, mut nodes: Vec<Arc<PeerInfo>>, mut handle_valid_node: F)
476where
477    F: FnMut(Arc<PeerInfo>) + Send,
478{
479    const SPAWN_THRESHOLD: usize = 4;
480
481    // NOTE: Ensure that we don't block the thread for too long
482    if nodes.len() > SPAWN_THRESHOLD {
483        let nodes = rayon_run(move || {
484            nodes.retain(|node| node.verify(now));
485            nodes
486        })
487        .await;
488
489        for node in nodes {
490            handle_valid_node(node);
491        }
492    } else {
493        for node in nodes {
494            let mut signature_checked = false;
495            let is_valid = node.verify_ext(now, &mut signature_checked);
496            yield_on_complex(signature_checked).await;
497
498            if is_valid {
499                handle_valid_node(node);
500            }
501        }
502    };
503}
504
/// How long a single DHT request may take before it is treated as a timeout (500 ms).
const REQUEST_TIMEOUT: Duration = Duration::new(0, 500_000_000);
/// Number of permits of the per-lookup semaphore bounding concurrent requests.
const MAX_PARALLEL_REQUESTS: usize = 10;