tycho_network/overlay/mod.rs

1use std::sync::Arc;
2
3use bytes::Buf;
4use tl_proto::{TlError, TlRead};
5use tokio::sync::Notify;
6use tycho_util::futures::BoxFutureOrNoop;
7use tycho_util::time::now_sec;
8use tycho_util::{FastDashMap, FastHashSet};
9
10pub use self::config::OverlayConfig;
11use self::entries_merger::PublicOverlayEntriesMerger;
12pub use self::overlay_id::OverlayId;
13pub use self::private_overlay::{
14    ChooseMultiplePrivateOverlayEntries, PrivateOverlay, PrivateOverlayBuilder,
15    PrivateOverlayEntries, PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard,
16    PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData,
17};
18pub use self::public_overlay::{
19    ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries,
20    PublicOverlayEntriesReadGuard, PublicOverlayEntryData, UnknownPeersQueue,
21};
22use crate::dht::DhtService;
23use crate::network::Network;
24use crate::proto::overlay::{PublicEntriesResponse, PublicEntry, PublicEntryResponse, rpc};
25use crate::types::{PeerId, Response, Service, ServiceRequest};
26use crate::util::Routable;
27
28mod background_tasks;
29mod config;
30mod entries_merger;
31mod metrics;
32mod overlay_id;
33mod private_overlay;
34mod public_overlay;
35mod tasks_stream;
36
37pub struct OverlayServiceBackgroundTasks {
38    inner: Arc<OverlayServiceInner>,
39    dht: Option<DhtService>,
40}
41
42impl OverlayServiceBackgroundTasks {
43    pub fn spawn(self, network: &Network) {
44        self.inner
45            .start_background_tasks(Network::downgrade(network), self.dht);
46    }
47}
48
49pub struct OverlayServiceBuilder {
50    local_id: PeerId,
51    config: Option<OverlayConfig>,
52    dht: Option<DhtService>,
53}
54
55impl OverlayServiceBuilder {
56    pub fn with_config(mut self, config: OverlayConfig) -> Self {
57        self.config = Some(config);
58        self
59    }
60
61    pub fn with_dht_service(mut self, dht: DhtService) -> Self {
62        self.dht = Some(dht);
63        self
64    }
65
66    pub fn build(self) -> (OverlayServiceBackgroundTasks, OverlayService) {
67        let config = self.config.unwrap_or_default();
68
69        let inner = Arc::new(OverlayServiceInner {
70            local_id: self.local_id,
71            config,
72            private_overlays: Default::default(),
73            public_overlays: Default::default(),
74            public_overlays_changed: Arc::new(Notify::new()),
75            private_overlays_changed: Arc::new(Notify::new()),
76            public_entries_merger: Arc::new(PublicOverlayEntriesMerger),
77        });
78
79        let background_tasks = OverlayServiceBackgroundTasks {
80            inner: inner.clone(),
81            dht: self.dht,
82        };
83
84        (background_tasks, OverlayService(inner))
85    }
86}
87
88#[derive(Clone)]
89pub struct OverlayService(Arc<OverlayServiceInner>);
90
91impl OverlayService {
92    pub fn builder(local_id: PeerId) -> OverlayServiceBuilder {
93        OverlayServiceBuilder {
94            local_id,
95            config: None,
96            dht: None,
97        }
98    }
99
100    pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
101        self.0.add_private_overlay(overlay)
102    }
103
104    pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
105        self.0.remove_private_overlay(overlay_id)
106    }
107
108    pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
109        self.0.add_public_overlay(overlay)
110    }
111
112    pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
113        self.0.remove_public_overlay(overlay_id)
114    }
115}
116
117impl Service<ServiceRequest> for OverlayService {
118    type QueryResponse = Response;
119    type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
120    type OnMessageFuture = BoxFutureOrNoop<()>;
121
122    #[tracing::instrument(
123        level = "debug",
124        name = "on_overlay_query",
125        skip_all,
126        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
127    )]
128    fn on_query(&self, mut req: ServiceRequest) -> Self::OnQueryFuture {
129        let e = 'req: {
130            let mut req_body = req.body.as_ref();
131            if req_body.len() < 4 {
132                break 'req TlError::UnexpectedEof;
133            }
134
135            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
136                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
137                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
138                    Err(e) => break 'req e,
139                },
140                rpc::ExchangeRandomPublicEntries::TL_ID => {
141                    let req = match tl_proto::deserialize::<rpc::ExchangeRandomPublicEntries>(
142                        &req.body,
143                    ) {
144                        Ok(req) => req,
145                        Err(e) => break 'req e,
146                    };
147                    tracing::debug!("exchange_random_public_entries");
148
149                    let res = self.0.handle_exchange_public_entries(&req);
150                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
151                        Response::from_tl(res),
152                    )));
153                }
154                rpc::GetPublicEntry::TL_ID => {
155                    let req = match tl_proto::deserialize::<rpc::GetPublicEntry>(&req.body) {
156                        Ok(req) => req,
157                        Err(e) => break 'req e,
158                    };
159                    tracing::debug!("get_public_entry");
160
161                    let res = self.0.handle_get_public_entry(&req);
162                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
163                        Response::from_tl(res),
164                    )));
165                }
166                _ => break 'req TlError::UnknownConstructor,
167            };
168
169            if req_body.len() < 4 {
170                // Definitely an invalid request (not enough bytes for the constructor)
171                break 'req TlError::UnexpectedEof;
172            }
173            let offset = req.body.len() - req_body.len();
174
175            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
176                req.body.advance(offset);
177                return private_overlay.handle_query(req);
178            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
179                req.body.advance(offset);
180                return public_overlay.handle_query(req);
181            }
182
183            tracing::debug!(
184                overlay_id = %OverlayId::wrap(overlay_id),
185                "unknown overlay id"
186            );
187            return BoxFutureOrNoop::Noop;
188        };
189
190        tracing::debug!("failed to deserialize query: {e:?}");
191        BoxFutureOrNoop::Noop
192    }
193
194    #[tracing::instrument(
195        level = "debug",
196        name = "on_overlay_message",
197        skip_all,
198        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
199    )]
200    fn on_message(&self, mut req: ServiceRequest) -> Self::OnMessageFuture {
201        // TODO: somehow refactor with one method for both query and message
202
203        let e = 'req: {
204            let mut req_body = req.body.as_ref();
205            if req_body.len() < 4 {
206                break 'req TlError::UnexpectedEof;
207            }
208
209            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
210                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
211                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
212                    Err(e) => break 'req e,
213                },
214                _ => break 'req TlError::UnknownConstructor,
215            };
216
217            if req_body.len() < 4 {
218                // Definitely an invalid request (not enough bytes for the constructor)
219                break 'req TlError::UnexpectedEof;
220            }
221            let offset = req.body.len() - req_body.len();
222
223            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
224                req.body.advance(offset);
225                return private_overlay.handle_message(req);
226            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
227                req.body.advance(offset);
228                return public_overlay.handle_message(req);
229            }
230
231            tracing::debug!(
232                overlay_id = %OverlayId::wrap(overlay_id),
233                "unknown overlay id"
234            );
235            return BoxFutureOrNoop::Noop;
236        };
237
238        tracing::debug!("failed to deserialize message: {e:?}");
239        BoxFutureOrNoop::Noop
240    }
241}
242
243impl Routable for OverlayService {
244    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
245        [
246            rpc::ExchangeRandomPublicEntries::TL_ID,
247            rpc::GetPublicEntry::TL_ID,
248            rpc::Prefix::TL_ID,
249        ]
250    }
251
252    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
253        [rpc::Prefix::TL_ID]
254    }
255}
256
257struct OverlayServiceInner {
258    local_id: PeerId,
259    config: OverlayConfig,
260    public_overlays: FastDashMap<OverlayId, PublicOverlay>,
261    private_overlays: FastDashMap<OverlayId, PrivateOverlay>,
262    public_overlays_changed: Arc<Notify>,
263    private_overlays_changed: Arc<Notify>,
264    public_entries_merger: Arc<PublicOverlayEntriesMerger>,
265}
266
267impl OverlayServiceInner {
268    fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
269        use dashmap::mapref::entry::Entry;
270
271        if self.public_overlays.contains_key(overlay.overlay_id()) {
272            return false;
273        }
274        match self.private_overlays.entry(*overlay.overlay_id()) {
275            Entry::Vacant(entry) => {
276                entry.insert(overlay.clone());
277                self.private_overlays_changed.notify_waiters();
278                true
279            }
280            Entry::Occupied(_) => false,
281        }
282    }
283
284    fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
285        let removed = self.private_overlays.remove(overlay_id).is_some();
286        if removed {
287            self.private_overlays_changed.notify_waiters();
288        }
289        removed
290    }
291
292    fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
293        use dashmap::mapref::entry::Entry;
294
295        if self.private_overlays.contains_key(overlay.overlay_id()) {
296            return false;
297        }
298        match self.public_overlays.entry(*overlay.overlay_id()) {
299            Entry::Vacant(entry) => {
300                entry.insert(overlay.clone());
301                self.public_overlays_changed.notify_waiters();
302                true
303            }
304            Entry::Occupied(_) => false,
305        }
306    }
307
308    fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
309        let removed = self.public_overlays.remove(overlay_id).is_some();
310        if removed {
311            self.public_overlays_changed.notify_waiters();
312        }
313        removed
314    }
315
316    fn handle_exchange_public_entries(
317        &self,
318        req: &rpc::ExchangeRandomPublicEntries,
319    ) -> PublicEntriesResponse {
320        // NOTE: validation is done in the TL parser.
321        debug_assert!(req.entries.len() <= 20);
322
323        // Find the overlay
324        let overlay = match self.public_overlays.get(&req.overlay_id) {
325            Some(overlay) => overlay,
326            None => return PublicEntriesResponse::OverlayNotFound,
327        };
328
329        // Add proposed entries to the overlay
330        overlay.add_untrusted_entries(&self.local_id, &req.entries, now_sec());
331
332        // Collect proposed entries to exclude from the response
333        let requested_ids = req
334            .entries
335            .iter()
336            .map(|id| id.peer_id)
337            .collect::<FastHashSet<_>>();
338
339        let entries = {
340            let entries = overlay.read_entries();
341
342            // Choose additional random entries to ensure we have enough new entries to send back
343            let n = self.config.exchange_public_entries_batch;
344            entries
345                .choose_multiple(&mut rand::rng(), n + requested_ids.len())
346                .filter_map(|item| {
347                    let is_new = !requested_ids.contains(&item.entry.peer_id);
348                    is_new.then(|| item.entry.clone())
349                })
350                .take(n)
351                .collect::<Vec<_>>()
352        };
353
354        PublicEntriesResponse::PublicEntries(entries)
355    }
356
357    fn handle_get_public_entry(&self, req: &rpc::GetPublicEntry) -> PublicEntryResponse {
358        // Find the overlay
359        let overlay = match self.public_overlays.get(&req.overlay_id) {
360            Some(overlay) => overlay,
361            None => return PublicEntryResponse::OverlayNotFound,
362        };
363
364        let Some(entry) = overlay.own_signed_entry() else {
365            // NOTE: We return `OverlayNotFound` because if there is no signed entry
366            // stored, then the background tasks are not running at all and this is
367            // kind of "shadow" mode which is identical to not being in the overlay.
368            return PublicEntryResponse::OverlayNotFound;
369        };
370
371        PublicEntryResponse::Found(entry)
372    }
373}