tycho_network/overlay/mod.rs

1use std::sync::Arc;
2
3use bytes::Buf;
4use tl_proto::{TlError, TlRead};
5use tokio::sync::Notify;
6use tycho_util::futures::BoxFutureOrNoop;
7use tycho_util::time::now_sec;
8use tycho_util::{FastDashMap, FastHashSet};
9
10pub use self::config::OverlayConfig;
11use self::entries_merger::PublicOverlayEntriesMerger;
12pub use self::overlay_id::OverlayId;
13pub use self::private_overlay::{
14    ChooseMultiplePrivateOverlayEntries, PrivateOverlay, PrivateOverlayBuilder,
15    PrivateOverlayEntries, PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard,
16    PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData,
17};
18pub use self::public_overlay::{
19    ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries,
20    PublicOverlayEntriesReadGuard, PublicOverlayEntryData,
21};
22use crate::dht::DhtService;
23use crate::network::Network;
24use crate::proto::overlay::{rpc, PublicEntriesResponse, PublicEntry};
25use crate::types::{PeerId, Response, Service, ServiceRequest};
26use crate::util::Routable;
27
28mod background_tasks;
29mod config;
30mod entries_merger;
31mod metrics;
32mod overlay_id;
33mod private_overlay;
34mod public_overlay;
35mod tasks_stream;
36
37pub struct OverlayServiceBackgroundTasks {
38    inner: Arc<OverlayServiceInner>,
39    dht: Option<DhtService>,
40}
41
42impl OverlayServiceBackgroundTasks {
43    pub fn spawn(self, network: &Network) {
44        self.inner
45            .start_background_tasks(Network::downgrade(network), self.dht);
46    }
47}
48
49pub struct OverlayServiceBuilder {
50    local_id: PeerId,
51    config: Option<OverlayConfig>,
52    dht: Option<DhtService>,
53}
54
55impl OverlayServiceBuilder {
56    pub fn with_config(mut self, config: OverlayConfig) -> Self {
57        self.config = Some(config);
58        self
59    }
60
61    pub fn with_dht_service(mut self, dht: DhtService) -> Self {
62        self.dht = Some(dht);
63        self
64    }
65
66    pub fn build(self) -> (OverlayServiceBackgroundTasks, OverlayService) {
67        let config = self.config.unwrap_or_default();
68
69        let inner = Arc::new(OverlayServiceInner {
70            local_id: self.local_id,
71            config,
72            private_overlays: Default::default(),
73            public_overlays: Default::default(),
74            public_overlays_changed: Arc::new(Notify::new()),
75            private_overlays_changed: Arc::new(Notify::new()),
76            public_entries_merger: Arc::new(PublicOverlayEntriesMerger),
77        });
78
79        let background_tasks = OverlayServiceBackgroundTasks {
80            inner: inner.clone(),
81            dht: self.dht,
82        };
83
84        (background_tasks, OverlayService(inner))
85    }
86}
87
88#[derive(Clone)]
89pub struct OverlayService(Arc<OverlayServiceInner>);
90
91impl OverlayService {
92    pub fn builder(local_id: PeerId) -> OverlayServiceBuilder {
93        OverlayServiceBuilder {
94            local_id,
95            config: None,
96            dht: None,
97        }
98    }
99
100    pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
101        self.0.add_private_overlay(overlay)
102    }
103
104    pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
105        self.0.remove_private_overlay(overlay_id)
106    }
107
108    pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
109        self.0.add_public_overlay(overlay)
110    }
111
112    pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
113        self.0.remove_public_overlay(overlay_id)
114    }
115}
116
117impl Service<ServiceRequest> for OverlayService {
118    type QueryResponse = Response;
119    type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
120    type OnMessageFuture = BoxFutureOrNoop<()>;
121    type OnDatagramFuture = futures_util::future::Ready<()>;
122
123    #[tracing::instrument(
124        level = "debug",
125        name = "on_overlay_query",
126        skip_all,
127        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
128    )]
129    fn on_query(&self, mut req: ServiceRequest) -> Self::OnQueryFuture {
130        let e = 'req: {
131            if req.body.len() < 4 {
132                break 'req TlError::UnexpectedEof;
133            }
134
135            // NOTE: `req.body` is untouched while reading the constructor
136            // and `as_ref` here is exactly for that.
137            let mut offset = 0;
138            let overlay_id = match req.body.as_ref().get_u32_le() {
139                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&req.body, &mut offset) {
140                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
141                    Err(e) => break 'req e,
142                },
143                rpc::ExchangeRandomPublicEntries::TL_ID => {
144                    let req = match tl_proto::deserialize::<rpc::ExchangeRandomPublicEntries>(
145                        &req.body,
146                    ) {
147                        Ok(req) => req,
148                        Err(e) => break 'req e,
149                    };
150                    tracing::debug!("exchange_random_public_entries");
151
152                    let res = self.0.handle_exchange_public_entries(&req);
153                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
154                        Response::from_tl(res),
155                    )));
156                }
157                _ => break 'req TlError::UnknownConstructor,
158            };
159
160            if req.body.len() < offset + 4 {
161                // Definitely an invalid request (not enough bytes for the constructor)
162                break 'req TlError::UnexpectedEof;
163            }
164
165            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
166                req.body.advance(offset);
167                return private_overlay.handle_query(req);
168            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
169                req.body.advance(offset);
170                return public_overlay.handle_query(req);
171            }
172
173            tracing::debug!(
174                overlay_id = %OverlayId::wrap(overlay_id),
175                "unknown overlay id"
176            );
177            return BoxFutureOrNoop::Noop;
178        };
179
180        tracing::debug!("failed to deserialize query: {e:?}");
181        BoxFutureOrNoop::Noop
182    }
183
184    #[tracing::instrument(
185        level = "debug",
186        name = "on_overlay_message",
187        skip_all,
188        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
189    )]
190    fn on_message(&self, mut req: ServiceRequest) -> Self::OnMessageFuture {
191        // TODO: somehow refactor with one method for both query and message
192
193        let e = 'req: {
194            if req.body.len() < 4 {
195                break 'req TlError::UnexpectedEof;
196            }
197
198            // NOTE: `req.body` is untouched while reading the constructor
199            // and `as_ref` here is exactly for that.
200            let mut offset = 0;
201            let overlay_id = match req.body.as_ref().get_u32_le() {
202                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&req.body, &mut offset) {
203                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
204                    Err(e) => break 'req e,
205                },
206                _ => break 'req TlError::UnknownConstructor,
207            };
208
209            if req.body.len() < offset + 4 {
210                // Definitely an invalid request (not enough bytes for the constructor)
211                break 'req TlError::UnexpectedEof;
212            }
213
214            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
215                req.body.advance(offset);
216                return private_overlay.handle_message(req);
217            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
218                req.body.advance(offset);
219                return public_overlay.handle_message(req);
220            }
221
222            tracing::debug!(
223                overlay_id = %OverlayId::wrap(overlay_id),
224                "unknown overlay id"
225            );
226            return BoxFutureOrNoop::Noop;
227        };
228
229        tracing::debug!("failed to deserialize message: {e:?}");
230        BoxFutureOrNoop::Noop
231    }
232
233    #[inline]
234    fn on_datagram(&self, _req: ServiceRequest) -> Self::OnDatagramFuture {
235        futures_util::future::ready(())
236    }
237}
238
239impl Routable for OverlayService {
240    fn query_ids(&self) -> impl IntoIterator<Item = u32> {
241        [rpc::ExchangeRandomPublicEntries::TL_ID, rpc::Prefix::TL_ID]
242    }
243
244    fn message_ids(&self) -> impl IntoIterator<Item = u32> {
245        [rpc::Prefix::TL_ID]
246    }
247}
248
249struct OverlayServiceInner {
250    local_id: PeerId,
251    config: OverlayConfig,
252    public_overlays: FastDashMap<OverlayId, PublicOverlay>,
253    private_overlays: FastDashMap<OverlayId, PrivateOverlay>,
254    public_overlays_changed: Arc<Notify>,
255    private_overlays_changed: Arc<Notify>,
256    public_entries_merger: Arc<PublicOverlayEntriesMerger>,
257}
258
259impl OverlayServiceInner {
260    fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
261        use dashmap::mapref::entry::Entry;
262
263        if self.public_overlays.contains_key(overlay.overlay_id()) {
264            return false;
265        }
266        match self.private_overlays.entry(*overlay.overlay_id()) {
267            Entry::Vacant(entry) => {
268                entry.insert(overlay.clone());
269                self.private_overlays_changed.notify_waiters();
270                true
271            }
272            Entry::Occupied(_) => false,
273        }
274    }
275
276    fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
277        let removed = self.private_overlays.remove(overlay_id).is_some();
278        if removed {
279            self.private_overlays_changed.notify_waiters();
280        }
281        removed
282    }
283
284    fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
285        use dashmap::mapref::entry::Entry;
286
287        if self.private_overlays.contains_key(overlay.overlay_id()) {
288            return false;
289        }
290        match self.public_overlays.entry(*overlay.overlay_id()) {
291            Entry::Vacant(entry) => {
292                entry.insert(overlay.clone());
293                self.public_overlays_changed.notify_waiters();
294                true
295            }
296            Entry::Occupied(_) => false,
297        }
298    }
299
300    fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
301        let removed = self.public_overlays.remove(overlay_id).is_some();
302        if removed {
303            self.public_overlays_changed.notify_waiters();
304        }
305        removed
306    }
307
308    fn handle_exchange_public_entries(
309        &self,
310        req: &rpc::ExchangeRandomPublicEntries,
311    ) -> PublicEntriesResponse {
312        // NOTE: validation is done in the TL parser.
313        debug_assert!(req.entries.len() <= 20);
314
315        // Find the overlay
316        let overlay = match self.public_overlays.get(&req.overlay_id) {
317            Some(overlay) => overlay,
318            None => return PublicEntriesResponse::OverlayNotFound,
319        };
320
321        // Add proposed entries to the overlay
322        overlay.add_untrusted_entries(&self.local_id, &req.entries, now_sec());
323
324        // Collect proposed entries to exclude from the response
325        let requested_ids = req
326            .entries
327            .iter()
328            .map(|id| id.peer_id)
329            .collect::<FastHashSet<_>>();
330
331        let entries = {
332            let entries = overlay.read_entries();
333
334            // Choose additional random entries to ensure we have enough new entries to send back
335            let n = self.config.exchange_public_entries_batch;
336            entries
337                .choose_multiple(&mut rand::thread_rng(), n + requested_ids.len())
338                .filter_map(|item| {
339                    let is_new = !requested_ids.contains(&item.entry.peer_id);
340                    is_new.then(|| item.entry.clone())
341                })
342                .take(n)
343                .collect::<Vec<_>>()
344        };
345
346        PublicEntriesResponse::PublicEntries(entries)
347    }
348}