// tycho_network/dht/query.rs
use std::collections::hash_map;
use std::sync::Arc;
use std::time::Duration;

use ahash::{HashMapExt, HashSetExt};
use anyhow::Result;
use bytes::Bytes;
use futures_util::stream::FuturesUnordered;
use futures_util::{Future, StreamExt};
use tokio::sync::Semaphore;
use tycho_util::futures::{JoinTask, Shared, WeakSharedHandle};
use tycho_util::sync::{rayon_run, yield_on_complex};
use tycho_util::time::now_sec;
use tycho_util::{FastDashMap, FastHashMap, FastHashSet};

use crate::dht::config::DhtConfig;
use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable};
use crate::network::Network;
use crate::proto::dht::{NodeResponse, Value, ValueRef, ValueResponse, rpc};
use crate::types::{PeerId, PeerInfo, Request};
use crate::util::NetworkExt;
/// Deduplicates concurrent queries for the same 32-byte target id.
///
/// All callers of [`QueryCache::run`] for the same id share one in-flight
/// task; entries are evicted once the last consumer is done.
pub struct QueryCache<R> {
    // Weak handles to shared spawned query tasks, keyed by target id.
    cache: FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
}
impl<R> QueryCache<R> {
    /// Runs `f` for `target_id`, deduplicating concurrent calls.
    ///
    /// If a query for the same id is already in flight, awaits the shared
    /// task instead of spawning a new one; every caller receives a clone of
    /// the same output. The cache entry is removed once the last consumer
    /// finishes or is cancelled.
    pub async fn run<F, Fut>(&self, target_id: &[u8; 32], f: F) -> R
    where
        R: Clone + Send + 'static,
        F: FnOnce() -> Fut,
        Fut: Future<Output = R> + Send + 'static,
    {
        use dashmap::mapref::entry::Entry;

        let fut = match self.cache.entry(*target_id) {
            // No query in flight: spawn the task and publish a weak handle.
            Entry::Vacant(entry) => {
                let fut = Shared::new(JoinTask::new(f()));
                // NOTE(review): `downgrade` can return `None` — presumably when
                // the shared task can no longer be weakly referenced; confirm
                // against `Shared::downgrade` in tycho_util.
                if let Some(weak) = fut.downgrade() {
                    entry.insert(weak);
                }
                fut
            }
            // An entry exists: reuse the live task, or replace a dead one.
            Entry::Occupied(mut entry) => {
                if let Some(fut) = entry.get().upgrade() {
                    fut
                } else {
                    // The previous task is gone; spawn a replacement and
                    // either refresh the weak handle or drop the stale entry.
                    let fut = Shared::new(JoinTask::new(f()));
                    match fut.downgrade() {
                        Some(weak) => entry.insert(weak),
                        None => entry.remove(),
                    };
                    fut
                }
            }
        };

        // Predicate for `remove_if`: only evict an entry once no strong
        // handles to its task remain.
        fn on_drop<R>(_key: &[u8; 32], value: &WeakSpawnedFut<R>) -> bool {
            value.strong_count() == 0
        }

        let (output, is_last) = {
            struct Guard<'a, R> {
                target_id: &'a [u8; 32],
                cache: &'a FastDashMap<[u8; 32], WeakSpawnedFut<R>>,
                fut: Option<Shared<JoinTask<R>>>,
            }

            impl<R> Drop for Guard<'_, R> {
                fn drop(&mut self) {
                    // Remove value from cache if we consumed the last future instance
                    if self.fut.take().map(Shared::consume).unwrap_or_default() {
                        self.cache.remove_if(self.target_id, on_drop);
                    }
                }
            }

            // Wrap future into guard to remove it from cache even if it was cancelled
            let mut guard = Guard {
                target_id,
                cache: &self.cache,
                fut: None,
            };
            let fut = guard.fut.insert(fut);

            // Await future.
            // If `Shared` future is not polled to `Complete` state,
            // the guard will try to consume it and remove from cache
            // if it was the last instance.
            fut.await
        };

        // TODO: add ttl and force others to make a request for a fresh data
        if is_last {
            // Remove value from cache if we consumed the last future instance
            self.cache.remove_if(target_id, on_drop);
        }

        output
    }
}
103impl<R> Default for QueryCache<R> {
104    fn default() -> Self {
105        Self {
106            cache: Default::default(),
107        }
108    }
109}
/// Weak handle to a shared background query task; lets a cache entry refer
/// to the task without keeping it alive.
type WeakSpawnedFut<T> = WeakSharedHandle<JoinTask<T>>;
/// Selects how the initial lookup candidates are picked from the routing table.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DhtQueryMode {
    /// Seed candidates closest to the target id itself.
    #[default]
    Closest,
    /// Seed candidates closest to a freshly generated random id.
    Random,
}
/// State of one iterative DHT lookup (find-value / find-peers).
pub struct Query {
    network: Network,
    // Working routing table whose "local id" is the lookup TARGET, so its
    // distance ordering is relative to the target, not to this node.
    candidates: SimpleRoutingTable,
    // Max closest peers requested/visited per round (from `DhtConfig::max_k`).
    max_k: usize,
    // Per-request timeout (from `DhtConfig::request_timeout`).
    timeout: Duration,
}
impl Query {
    /// Builds the lookup state for `target_id`.
    ///
    /// Seeds the candidate table with up to `config.max_k` peers from
    /// `routing_table` closest either to the target itself (`Closest`)
    /// or to a random id (`Random`).
    pub fn new(
        network: Network,
        routing_table: &HandlesRoutingTable,
        target_id: &[u8; 32],
        config: &DhtConfig,
        mode: DhtQueryMode,
    ) -> Self {
        // The candidate table is keyed by the lookup target: `local_id`
        // below therefore means "target", not "this node".
        let mut candidates = SimpleRoutingTable::new(PeerId(*target_id));

        // `random_id` must outlive the borrow taken in the `Random` arm.
        let random_id;
        let target_id_for_full = match mode {
            DhtQueryMode::Closest => target_id,
            DhtQueryMode::Random => {
                random_id = rand::random();
                &random_id
            }
        };

        let max_k = config.max_k;
        let timeout = config.request_timeout;

        routing_table.visit_closest(target_id_for_full, max_k, |node| {
            // `Duration::MAX` presumably disables age-based filtering, and
            // `Some` acts as an accept-everything mapper — TODO confirm
            // against `SimpleRoutingTable::add`.
            candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some);
        });

        Self {
            network,
            candidates,
            max_k,
            timeout,
        }
    }

    /// The lookup target id (stored as the candidate table's local id).
    fn local_id(&self) -> &[u8; 32] {
        self.candidates.local_id.as_bytes()
    }

    /// Iteratively queries peers for the value under the target key.
    ///
    /// Returns the first value that passes `verify_ext`, or `None` once
    /// every reachable candidate has been exhausted.
    #[tracing::instrument(skip_all)]
    pub async fn find_value(mut self) -> Option<Box<Value>> {
        // Prepare shared request
        // NOTE: `local_id()` is the lookup target here (see `Query::new`).
        let request_body = Bytes::from(tl_proto::serialize(rpc::FindValue {
            key: *self.local_id(),
            k: self.max_k as u32,
        }));

        // Prepare request to initial candidates
        let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
        let mut futures = FuturesUnordered::new();
        self.candidates
            .visit_closest(self.local_id(), self.max_k, |node| {
                futures.push(Self::visit::<ValueResponse>(
                    self.network.clone(),
                    node.clone(),
                    request_body.clone(),
                    &semaphore,
                    self.timeout,
                ));
            });

        // Process responses and refill futures until the value is found or all peers are traversed
        let mut visited = FastHashSet::new();
        while let Some((node, res)) = futures.next().await {
            match res {
                // Return the value if found
                Some(Ok(ValueResponse::Found(value))) => {
                    let mut signature_checked = false;
                    let is_valid =
                        value.verify_ext(now_sec(), self.local_id(), &mut signature_checked);
                    tracing::debug!(peer_id = %node.id, is_valid, "found value");

                    // Yield back to the runtime if an expensive signature
                    // check was actually performed.
                    yield_on_complex(signature_checked).await;

                    if !is_valid {
                        // Ignore invalid values
                        continue;
                    }

                    return Some(value);
                }
                // Refill futures from the nodes response
                Some(Ok(ValueResponse::NotFound(nodes))) => {
                    let node_count = nodes.len();
                    let has_new = self
                        .update_candidates(now_sec(), self.max_k, nodes, &mut visited)
                        .await;
                    tracing::debug!(peer_id = %node.id, count = node_count, has_new, "received nodes");

                    if !has_new {
                        // Do nothing if candidates were not changed
                        continue;
                    }

                    // Add new nodes from the closest range
                    // NOTE(review): `update_candidates` above has just inserted
                    // the newly discovered peers into `visited`, so this check
                    // appears to skip exactly those new peers — verify that the
                    // traversal actually reaches nodes discovered mid-query.
                    self.candidates
                        .visit_closest(self.local_id(), self.max_k, |node| {
                            if visited.contains(&node.id) {
                                // Skip already visited nodes
                                return;
                            }
                            futures.push(Self::visit::<ValueResponse>(
                                self.network.clone(),
                                node.clone(),
                                request_body.clone(),
                                &semaphore,
                                self.timeout,
                            ));
                        });
                }
                // Do nothing on error
                Some(Err(e)) => {
                    tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
                }
                // Do nothing on timeout
                None => {
                    tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
                }
            }
        }

        // Done
        None
    }

    /// Iteratively collects peers close to the target.
    ///
    /// `depth` bounds the number of refill rounds (`None` = unbounded).
    /// Returns every valid peer info seen, keeping the newest entry per id.
    #[tracing::instrument(skip_all)]
    pub async fn find_peers(mut self, depth: Option<usize>) -> FastHashMap<PeerId, Arc<PeerInfo>> {
        // Prepare shared request
        let request_body = Bytes::from(tl_proto::serialize(rpc::FindNode {
            key: *self.local_id(),
            k: self.max_k as u32,
        }));

        // Prepare request to initial candidates
        let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS);
        let mut futures = FuturesUnordered::new();
        self.candidates
            .visit_closest(self.local_id(), self.max_k, |node| {
                futures.push(Self::visit::<NodeResponse>(
                    self.network.clone(),
                    node.clone(),
                    request_body.clone(),
                    &semaphore,
                    self.timeout,
                ));
            });

        // Process responses and refill futures until all peers are traversed
        let mut current_depth = 0;
        let max_depth = depth.unwrap_or(usize::MAX);
        let mut result = FastHashMap::<PeerId, Arc<PeerInfo>>::new();
        while let Some((node, res)) = futures.next().await {
            match res {
                // Refill futures from the nodes response
                Some(Ok(NodeResponse { nodes })) => {
                    tracing::debug!(peer_id = %node.id, count = nodes.len(), "received nodes");
                    if !self
                        .update_candidates_full(now_sec(), self.max_k, nodes, &mut result)
                        .await
                    {
                        // Do nothing if candidates were not changed
                        continue;
                    }

                    // Depth counts rounds that produced at least one new peer.
                    current_depth += 1;
                    if current_depth >= max_depth {
                        // Stop on max depth
                        break;
                    }

                    // Add new nodes from the closest range
                    // NOTE(review): same concern as in `find_value` — `result`
                    // already contains the peers just added by
                    // `update_candidates_full`, so they are skipped here.
                    self.candidates
                        .visit_closest(self.local_id(), self.max_k, |node| {
                            if result.contains_key(&node.id) {
                                // Skip already visited nodes
                                return;
                            }
                            futures.push(Self::visit::<NodeResponse>(
                                self.network.clone(),
                                node.clone(),
                                request_body.clone(),
                                &semaphore,
                                self.timeout,
                            ));
                        });
                }
                // Do nothing on error
                Some(Err(e)) => {
                    tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}");
                }
                // Do nothing on timeout
                None => {
                    tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout");
                }
            }
        }

        // Done
        result
    }

    /// Adds valid, previously unseen `nodes` to the candidate table.
    ///
    /// Marks every valid node as visited; returns `true` if at least one
    /// new candidate was added.
    async fn update_candidates(
        &mut self,
        now: u32,
        max_k: usize,
        nodes: Vec<Arc<PeerInfo>>,
        visited: &mut FastHashSet<PeerId>,
    ) -> bool {
        let mut has_new = false;
        process_only_valid(now, nodes, |node| {
            // Insert a new entry
            if visited.insert(node.id) {
                self.candidates.add(node, max_k, &Duration::MAX, Some);
                has_new = true;
            }
        })
        .await;

        has_new
    }

    /// Like [`Self::update_candidates`], but also keeps the full peer info
    /// in `visited`, replacing an existing entry when a newer
    /// (`created_at`) record arrives.
    ///
    /// Returns `true` only when a brand-new peer id was added; refreshing
    /// an existing entry does not count.
    async fn update_candidates_full(
        &mut self,
        now: u32,
        max_k: usize,
        nodes: Vec<Arc<PeerInfo>>,
        visited: &mut FastHashMap<PeerId, Arc<PeerInfo>>,
    ) -> bool {
        let mut has_new = false;
        process_only_valid(now, nodes, |node| {
            match visited.entry(node.id) {
                // Insert a new entry
                hash_map::Entry::Vacant(entry) => {
                    let node = entry.insert(node).clone();
                    self.candidates.add(node, max_k, &Duration::MAX, Some);
                    has_new = true;
                }
                // Try to replace an old entry
                hash_map::Entry::Occupied(mut entry) => {
                    if entry.get().created_at < node.created_at {
                        *entry.get_mut() = node;
                    }
                }
            }
        })
        .await;

        has_new
    }

    /// Sends one query to `node` and deserializes the response as `T`.
    ///
    /// Concurrency is bounded by `semaphore`. Returns `None` when the
    /// semaphore is closed or the request timed out; otherwise the query
    /// result (transport or decode errors as `Err`).
    async fn visit<T>(
        network: Network,
        node: Arc<PeerInfo>,
        request_body: Bytes,
        semaphore: &Semaphore,
        timeout: Duration,
    ) -> (Arc<PeerInfo>, Option<Result<T>>)
    where
        for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
    {
        // Acquire a slot; a closed semaphore is reported like a timeout.
        let Ok(_permit) = semaphore.acquire().await else {
            return (node, None);
        };

        let req = network.query(&node.id, Request {
            version: Default::default(),
            body: request_body.clone(),
        });

        let res = match tokio::time::timeout(timeout, req).await {
            Ok(res) => {
                Some(res.and_then(|res| tl_proto::deserialize::<T>(&res.body).map_err(Into::into)))
            }
            Err(_) => None,
        };

        (node, res)
    }
}
/// Fan-out "store value" operation over the peers closest to the value key.
pub struct StoreValue<F = ()> {
    // One pending store request per selected peer; driven by `run`.
    futures: FuturesUnordered<F>,
}
impl StoreValue<()> {
    /// Prepares one store request per peer close to the value's key hash.
    ///
    /// When `local_peer_info` is given, it is bundled with the store request.
    /// Nothing is sent until the returned [`StoreValue`] is driven via `run`.
    pub fn new(
        network: Network,
        routing_table: &HandlesRoutingTable,
        value: &ValueRef<'_>,
        config: &DhtConfig,
        local_peer_info: Option<&PeerInfo>,
    ) -> StoreValue<impl Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send + use<>> {
        // Target peers by the hash of the value's key, whichever variant it is.
        let key_hash = match value {
            ValueRef::Peer(value) => tl_proto::hash(&value.key),
            ValueRef::Merged(value) => tl_proto::hash(&value.key),
        };

        let request_body = Bytes::from(match local_peer_info {
            Some(peer_info) => tl_proto::serialize((
                rpc::WithPeerInfo::wrap(peer_info),
                rpc::StoreRef::wrap(value),
            )),
            None => tl_proto::serialize(rpc::StoreRef::wrap(value)),
        });

        // NOTE(review): hard-coded `10` equals `MAX_PARALLEL_REQUESTS`;
        // consider reusing the constant so the limits stay in sync.
        let semaphore = Arc::new(Semaphore::new(10));
        let futures = futures_util::stream::FuturesUnordered::new();
        routing_table.visit_closest(&key_hash, config.max_k, |node| {
            futures.push(Self::visit(
                network.clone(),
                node.load_peer_info(),
                request_body.clone(),
                semaphore.clone(),
                config.request_timeout,
            ));
        });

        StoreValue { futures }
    }

    /// Sends a single store request to `node`.
    ///
    /// Returns `None` when the semaphore is closed or the request timed
    /// out; otherwise the transport result.
    async fn visit(
        network: Network,
        node: Arc<PeerInfo>,
        request_body: Bytes,
        semaphore: Arc<Semaphore>,
        timeout: Duration,
    ) -> (Arc<PeerInfo>, Option<Result<()>>) {
        // Bound the number of concurrent store requests.
        let Ok(_permit) = semaphore.acquire().await else {
            return (node, None);
        };

        let req = network.send(&node.id, Request {
            version: Default::default(),
            body: request_body.clone(),
        });

        let res = (tokio::time::timeout(timeout, req).await).ok();

        (node, res)
    }
}
468impl<T: Future<Output = (Arc<PeerInfo>, Option<Result<()>>)> + Send> StoreValue<T> {
469    #[tracing::instrument(level = "debug", skip_all, name = "store_value")]
470    pub async fn run(mut self) {
471        while let Some((node, res)) = self.futures.next().await {
472            match res {
473                Some(Ok(())) => {
474                    tracing::debug!(peer_id = %node.id, "value stored");
475                }
476                Some(Err(e)) => {
477                    tracing::warn!(peer_id = %node.id, "failed to store value: {e}");
478                }
479                // Do nothing on timeout
480                None => {
481                    tracing::warn!(peer_id = %node.id, "failed to store value: timeout");
482                }
483            }
484        }
485    }
486}
488async fn process_only_valid<F>(now: u32, mut nodes: Vec<Arc<PeerInfo>>, mut handle_valid_node: F)
489where
490    F: FnMut(Arc<PeerInfo>) + Send,
491{
492    const SPAWN_THRESHOLD: usize = 4;
493
494    // NOTE: Ensure that we don't block the thread for too long
495    if nodes.len() > SPAWN_THRESHOLD {
496        let nodes = rayon_run(move || {
497            nodes.retain(|node| node.verify(now));
498            nodes
499        })
500        .await;
501
502        for node in nodes {
503            handle_valid_node(node);
504        }
505    } else {
506        for node in nodes {
507            let mut signature_checked = false;
508            let is_valid = node.verify_ext(now, &mut signature_checked);
509            yield_on_complex(signature_checked).await;
510
511            if is_valid {
512                handle_valid_node(node);
513            }
514        }
515    };
516}
/// Upper bound on concurrently in-flight requests per query (semaphore size).
const MAX_PARALLEL_REQUESTS: usize = 10;