tycho_network/overlay/
mod.rs

1use std::sync::Arc;
2
3use bytes::Buf;
4use tl_proto::{TlError, TlRead};
5use tokio::sync::Notify;
6use tycho_util::futures::BoxFutureOrNoop;
7use tycho_util::time::now_sec;
8use tycho_util::{FastDashMap, FastHashMap, FastHashSet};
9
10pub use self::config::OverlayConfig;
11use self::entries_merger::PublicOverlayEntriesMerger;
12pub use self::overlay_id::OverlayId;
13pub use self::private_overlay::{
14    ChooseMultiplePrivateOverlayEntries, PrivateOverlay, PrivateOverlayBuilder,
15    PrivateOverlayEntries, PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard,
16    PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData,
17};
18pub use self::public_overlay::{
19    ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries,
20    PublicOverlayEntriesReadGuard, PublicOverlayEntryData, UnknownPeersQueue,
21};
22use crate::dht::DhtService;
23use crate::network::Network;
24use crate::proto::overlay::{PublicEntriesResponse, PublicEntry, PublicEntryResponse, rpc};
25use crate::types::{PeerId, Response, Service, ServiceRequest};
26use crate::util::Routable;
27
28mod background_tasks;
29mod config;
30mod entries_merger;
31mod metrics;
32mod overlay_id;
33mod private_overlay;
34mod public_overlay;
35mod tasks_stream;
36
37pub struct OverlayServiceBackgroundTasks {
38    inner: Arc<OverlayServiceInner>,
39    dht: Option<DhtService>,
40}
41
42impl OverlayServiceBackgroundTasks {
43    pub fn spawn(self, network: &Network) {
44        self.inner
45            .start_background_tasks(Network::downgrade(network), self.dht);
46    }
47}
48
49pub struct OverlayServiceBuilder {
50    local_id: PeerId,
51    config: Option<OverlayConfig>,
52    dht: Option<DhtService>,
53}
54
55impl OverlayServiceBuilder {
56    pub fn with_config(mut self, config: OverlayConfig) -> Self {
57        self.config = Some(config);
58        self
59    }
60
61    pub fn with_dht_service(mut self, dht: DhtService) -> Self {
62        self.dht = Some(dht);
63        self
64    }
65
66    pub fn build(self) -> (OverlayServiceBackgroundTasks, OverlayService) {
67        let config = self.config.unwrap_or_default();
68
69        let inner = Arc::new(OverlayServiceInner {
70            local_id: self.local_id,
71            config,
72            private_overlays: Default::default(),
73            public_overlays: Default::default(),
74            public_overlays_changed: Arc::new(Notify::new()),
75            private_overlays_changed: Arc::new(Notify::new()),
76            public_entries_merger: Arc::new(PublicOverlayEntriesMerger),
77        });
78
79        let background_tasks = OverlayServiceBackgroundTasks {
80            inner: inner.clone(),
81            dht: self.dht,
82        };
83
84        (background_tasks, OverlayService(inner))
85    }
86}
87
88#[derive(Clone)]
89pub struct OverlayService(Arc<OverlayServiceInner>);
90
91impl OverlayService {
92    pub fn builder(local_id: PeerId) -> OverlayServiceBuilder {
93        OverlayServiceBuilder {
94            local_id,
95            config: None,
96            dht: None,
97        }
98    }
99
100    pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
101        self.0.add_private_overlay(overlay)
102    }
103
104    pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
105        self.0.remove_private_overlay(overlay_id)
106    }
107
108    pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
109        self.0.add_public_overlay(overlay)
110    }
111
112    pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
113        self.0.remove_public_overlay(overlay_id)
114    }
115
116    pub fn public_overlays(&self) -> FastHashMap<OverlayId, PublicOverlay> {
117        self.0
118            .public_overlays
119            .iter()
120            .map(|item| (*item.key(), item.value().clone()))
121            .collect()
122    }
123
124    pub fn private_overlays(&self) -> FastHashMap<OverlayId, PrivateOverlay> {
125        self.0
126            .private_overlays
127            .iter()
128            .map(|item| (*item.key(), item.value().clone()))
129            .collect()
130    }
131}
132
133impl Service<ServiceRequest> for OverlayService {
134    type QueryResponse = Response;
135    type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
136    type OnMessageFuture = BoxFutureOrNoop<()>;
137
138    #[tracing::instrument(
139        level = "debug",
140        name = "on_overlay_query",
141        skip_all,
142        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
143    )]
144    fn on_query(&self, mut req: ServiceRequest) -> Self::OnQueryFuture {
145        let e = 'req: {
146            let mut req_body = req.body.as_ref();
147            if req_body.len() < 4 {
148                break 'req TlError::UnexpectedEof;
149            }
150
151            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
152                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
153                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
154                    Err(e) => break 'req e,
155                },
156                rpc::ExchangeRandomPublicEntries::TL_ID => {
157                    let req = match tl_proto::deserialize::<rpc::ExchangeRandomPublicEntries>(
158                        &req.body,
159                    ) {
160                        Ok(req) => req,
161                        Err(e) => break 'req e,
162                    };
163                    tracing::debug!("exchange_random_public_entries");
164
165                    let res = self.0.handle_exchange_public_entries(&req);
166                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
167                        Response::from_tl(res),
168                    )));
169                }
170                rpc::GetPublicEntry::TL_ID => {
171                    let req = match tl_proto::deserialize::<rpc::GetPublicEntry>(&req.body) {
172                        Ok(req) => req,
173                        Err(e) => break 'req e,
174                    };
175                    tracing::debug!("get_public_entry");
176
177                    let res = self.0.handle_get_public_entry(&req);
178                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
179                        Response::from_tl(res),
180                    )));
181                }
182                _ => break 'req TlError::UnknownConstructor,
183            };
184
185            if req_body.len() < 4 {
186                // Definitely an invalid request (not enough bytes for the constructor)
187                break 'req TlError::UnexpectedEof;
188            }
189            let offset = req.body.len() - req_body.len();
190
191            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
192                req.body.advance(offset);
193                return private_overlay.handle_query(req);
194            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
195                req.body.advance(offset);
196                return public_overlay.handle_query(req);
197            }
198
199            tracing::debug!(
200                overlay_id = %OverlayId::wrap(overlay_id),
201                "unknown overlay id"
202            );
203            return BoxFutureOrNoop::Noop;
204        };
205
206        tracing::debug!("failed to deserialize query: {e:?}");
207        BoxFutureOrNoop::Noop
208    }
209
210    #[tracing::instrument(
211        level = "debug",
212        name = "on_overlay_message",
213        skip_all,
214        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
215    )]
216    fn on_message(&self, mut req: ServiceRequest) -> Self::OnMessageFuture {
217        // TODO: somehow refactor with one method for both query and message
218
219        let e = 'req: {
220            let mut req_body = req.body.as_ref();
221            if req_body.len() < 4 {
222                break 'req TlError::UnexpectedEof;
223            }
224
225            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
226                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
227                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
228                    Err(e) => break 'req e,
229                },
230                _ => break 'req TlError::UnknownConstructor,
231            };
232
233            if req_body.len() < 4 {
234                // Definitely an invalid request (not enough bytes for the constructor)
235                break 'req TlError::UnexpectedEof;
236            }
237            let offset = req.body.len() - req_body.len();
238
239            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
240                req.body.advance(offset);
241                return private_overlay.handle_message(req);
242            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
243                req.body.advance(offset);
244                return public_overlay.handle_message(req);
245            }
246
247            tracing::debug!(
248                overlay_id = %OverlayId::wrap(overlay_id),
249                "unknown overlay id"
250            );
251            return BoxFutureOrNoop::Noop;
252        };
253
254        tracing::debug!("failed to deserialize message: {e:?}");
255        BoxFutureOrNoop::Noop
256    }
257}
258
259impl Routable for OverlayService {
260    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
261        [
262            rpc::ExchangeRandomPublicEntries::TL_ID,
263            rpc::GetPublicEntry::TL_ID,
264            rpc::Prefix::TL_ID,
265        ]
266    }
267
268    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
269        [rpc::Prefix::TL_ID]
270    }
271}
272
273struct OverlayServiceInner {
274    local_id: PeerId,
275    config: OverlayConfig,
276    public_overlays: FastDashMap<OverlayId, PublicOverlay>,
277    private_overlays: FastDashMap<OverlayId, PrivateOverlay>,
278    public_overlays_changed: Arc<Notify>,
279    private_overlays_changed: Arc<Notify>,
280    public_entries_merger: Arc<PublicOverlayEntriesMerger>,
281}
282
283impl OverlayServiceInner {
284    fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
285        use dashmap::mapref::entry::Entry;
286
287        if self.public_overlays.contains_key(overlay.overlay_id()) {
288            return false;
289        }
290        match self.private_overlays.entry(*overlay.overlay_id()) {
291            Entry::Vacant(entry) => {
292                entry.insert(overlay.clone());
293                self.private_overlays_changed.notify_waiters();
294                true
295            }
296            Entry::Occupied(_) => false,
297        }
298    }
299
300    fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
301        let removed = self.private_overlays.remove(overlay_id).is_some();
302        if removed {
303            self.private_overlays_changed.notify_waiters();
304        }
305        removed
306    }
307
308    fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
309        use dashmap::mapref::entry::Entry;
310
311        if self.private_overlays.contains_key(overlay.overlay_id()) {
312            return false;
313        }
314        match self.public_overlays.entry(*overlay.overlay_id()) {
315            Entry::Vacant(entry) => {
316                entry.insert(overlay.clone());
317                self.public_overlays_changed.notify_waiters();
318                true
319            }
320            Entry::Occupied(_) => false,
321        }
322    }
323
324    fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
325        let removed = self.public_overlays.remove(overlay_id).is_some();
326        if removed {
327            self.public_overlays_changed.notify_waiters();
328        }
329        removed
330    }
331
332    fn handle_exchange_public_entries(
333        &self,
334        req: &rpc::ExchangeRandomPublicEntries,
335    ) -> PublicEntriesResponse {
336        // NOTE: validation is done in the TL parser.
337        debug_assert!(req.entries.len() <= 20);
338
339        // Find the overlay
340        let overlay = match self.public_overlays.get(&req.overlay_id) {
341            Some(overlay) => overlay,
342            None => return PublicEntriesResponse::OverlayNotFound,
343        };
344
345        // Add proposed entries to the overlay
346        overlay.add_untrusted_entries(&self.local_id, &req.entries, now_sec());
347
348        // Collect proposed entries to exclude from the response
349        let requested_ids = req
350            .entries
351            .iter()
352            .map(|id| id.peer_id)
353            .collect::<FastHashSet<_>>();
354
355        let entries = {
356            let entries = overlay.read_entries();
357
358            // Choose additional random entries to ensure we have enough new entries to send back
359            let n = self.config.exchange_public_entries_batch;
360            entries
361                .choose_multiple(&mut rand::rng(), n + requested_ids.len())
362                .filter_map(|item| {
363                    let is_new = !requested_ids.contains(&item.entry.peer_id);
364                    is_new.then(|| item.entry.clone())
365                })
366                .take(n)
367                .collect::<Vec<_>>()
368        };
369
370        PublicEntriesResponse::PublicEntries(entries)
371    }
372
373    fn handle_get_public_entry(&self, req: &rpc::GetPublicEntry) -> PublicEntryResponse {
374        // Find the overlay
375        let overlay = match self.public_overlays.get(&req.overlay_id) {
376            Some(overlay) => overlay,
377            None => return PublicEntryResponse::OverlayNotFound,
378        };
379
380        let Some(entry) = overlay.own_signed_entry() else {
381            // NOTE: We return `OverlayNotFound` because if there is no signed entry
382            // stored, then the background tasks are not running at all and this is
383            // kind of "shadow" mode which is identical to not being in the overlay.
384            return PublicEntryResponse::OverlayNotFound;
385        };
386
387        PublicEntryResponse::Found(entry)
388    }
389}