tycho_network/overlay/
private_overlay.rs

1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::Result;
5use bytes::{Bytes, BytesMut};
6use indexmap::IndexMap;
7use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
8use rand::Rng;
9use tokio::sync::broadcast;
10use tycho_util::futures::BoxFutureOrNoop;
11use tycho_util::{FastHashSet, FastHasherState};
12
13use crate::dht::{PeerResolver, PeerResolverHandle};
14use crate::network::Network;
15use crate::overlay::OverlayId;
16use crate::overlay::metrics::Metrics;
17use crate::proto::overlay::rpc;
18use crate::types::{BoxService, PeerId, Request, Response, Service, ServiceExt, ServiceRequest};
19use crate::util::NetworkExt;
20
21pub struct PrivateOverlayBuilder {
22    overlay_id: OverlayId,
23    entries: FastHashSet<PeerId>,
24    entry_events_channel_size: usize,
25    peer_resolver: Option<PeerResolver>,
26    name: Option<&'static str>,
27}
28
29impl PrivateOverlayBuilder {
30    pub fn with_entries<I>(mut self, allowed_peers: I) -> Self
31    where
32        I: IntoIterator,
33        I::Item: Borrow<PeerId>,
34    {
35        self.entries
36            .extend(allowed_peers.into_iter().map(|p| *p.borrow()));
37        self
38    }
39
40    /// The capacity of entries set events.
41    ///
42    /// Default: 100.
43    pub fn with_entry_events_channel_size(mut self, entry_events_channel_size: usize) -> Self {
44        self.entry_events_channel_size = entry_events_channel_size;
45        self
46    }
47
48    /// Whether to resolve peers with the provided resolver.
49    ///
50    /// Does not resolve peers by default.
51    pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self {
52        self.peer_resolver = Some(peer_resolver);
53        self
54    }
55
56    /// Name of the overlay used in metrics.
57    pub fn named(mut self, name: &'static str) -> Self {
58        self.name = Some(name);
59        self
60    }
61
62    pub fn build<S>(self, service: S) -> PrivateOverlay
63    where
64        S: Send + Sync + 'static,
65        S: Service<ServiceRequest, QueryResponse = Response>,
66    {
67        let request_prefix = tl_proto::serialize(rpc::Prefix {
68            overlay_id: self.overlay_id.as_bytes(),
69        });
70
71        let mut entries = PrivateOverlayEntries {
72            items: Default::default(),
73            events_tx: broadcast::channel(self.entry_events_channel_size).0,
74            peer_resolver: self.peer_resolver,
75        };
76        for peer_id in self.entries {
77            entries.insert(&peer_id);
78        }
79
80        PrivateOverlay {
81            inner: Arc::new(Inner {
82                overlay_id: self.overlay_id,
83                entries: RwLock::new(entries),
84                service: service.boxed(),
85                request_prefix: request_prefix.into_boxed_slice(),
86                metrics: self
87                    .name
88                    .map(|label| Metrics::new("tycho_private_overlay", label))
89                    .unwrap_or_default(),
90            }),
91        }
92    }
93}
94
95#[derive(Clone)]
96pub struct PrivateOverlay {
97    inner: Arc<Inner>,
98}
99
100impl PrivateOverlay {
101    pub fn builder(overlay_id: OverlayId) -> PrivateOverlayBuilder {
102        PrivateOverlayBuilder {
103            overlay_id,
104            entries: Default::default(),
105            entry_events_channel_size: 100,
106            peer_resolver: None,
107            name: None,
108        }
109    }
110
111    #[inline]
112    pub fn overlay_id(&self) -> &OverlayId {
113        &self.inner.overlay_id
114    }
115
116    pub async fn query(
117        &self,
118        network: &Network,
119        peer_id: &PeerId,
120        mut request: Request,
121    ) -> Result<Response> {
122        self.inner.metrics.record_rx(request.body.len());
123        self.prepend_prefix_to_body(&mut request.body);
124        network.query(peer_id, request).await
125    }
126
127    pub async fn send(
128        &self,
129        network: &Network,
130        peer_id: &PeerId,
131        mut request: Request,
132    ) -> Result<()> {
133        self.inner.metrics.record_rx(request.body.len());
134        self.prepend_prefix_to_body(&mut request.body);
135        network.send(peer_id, request).await
136    }
137
138    pub fn write_entries(&self) -> PrivateOverlayEntriesWriteGuard<'_> {
139        PrivateOverlayEntriesWriteGuard {
140            entries: self.inner.entries.write(),
141        }
142    }
143
144    pub fn read_entries(&self) -> PrivateOverlayEntriesReadGuard<'_> {
145        PrivateOverlayEntriesReadGuard {
146            entries: self.inner.entries.read(),
147        }
148    }
149
150    pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop<Option<Response>> {
151        self.inner.metrics.record_rx(req.body.len());
152        if self.inner.entries.read().contains(&req.metadata.peer_id) {
153            BoxFutureOrNoop::future(self.inner.service.on_query(req))
154        } else {
155            BoxFutureOrNoop::Noop
156        }
157    }
158
159    pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> {
160        self.inner.metrics.record_rx(req.body.len());
161        if self.inner.entries.read().contains(&req.metadata.peer_id) {
162            BoxFutureOrNoop::future(self.inner.service.on_message(req))
163        } else {
164            BoxFutureOrNoop::Noop
165        }
166    }
167
168    fn prepend_prefix_to_body(&self, body: &mut Bytes) {
169        // TODO: reduce allocations
170        let mut res = BytesMut::with_capacity(self.inner.request_prefix.len() + body.len());
171        res.extend_from_slice(&self.inner.request_prefix);
172        res.extend_from_slice(body);
173        *body = res.freeze();
174    }
175}
176
177impl std::fmt::Debug for PrivateOverlay {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("PrivateOverlay")
180            .field("overlay_id", &self.inner.overlay_id)
181            .finish()
182    }
183}
184
185struct Inner {
186    overlay_id: OverlayId,
187    entries: RwLock<PrivateOverlayEntries>,
188    service: BoxService<ServiceRequest, Response>,
189    request_prefix: Box<[u8]>,
190    metrics: Metrics,
191}
192
193// NOTE: `#[derive(Default)]` is missing to prevent construction outside the
194// crate.
195pub struct PrivateOverlayEntries {
196    items: OverlayItems,
197    events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
198    peer_resolver: Option<PeerResolver>,
199}
200
201impl PrivateOverlayEntries {
202    /// Subscribes to the set updates.
203    pub fn subscribe(&self) -> broadcast::Receiver<PrivateOverlayEntriesEvent> {
204        self.events_tx.subscribe()
205    }
206
207    /// Returns an iterator over the entry ids.
208    ///
209    /// The order is not random, but is not defined.
210    pub fn iter(&self) -> indexmap::map::Values<'_, PeerId, PrivateOverlayEntryData> {
211        self.items.values()
212    }
213
214    /// Returns one random peer, or `None` if set is empty.
215    pub fn choose<R>(&self, rng: &mut R) -> Option<&PrivateOverlayEntryData>
216    where
217        R: Rng + ?Sized,
218    {
219        let index = rng.random_range(0..self.items.len());
220        let (_, value) = self.items.get_index(index)?;
221        Some(value)
222    }
223
224    /// Chooses `n` entries from the set, without repetition,
225    /// and in random order.
226    pub fn choose_multiple<R>(
227        &self,
228        rng: &mut R,
229        n: usize,
230    ) -> ChooseMultiplePrivateOverlayEntries<'_>
231    where
232        R: Rng + ?Sized,
233    {
234        let len = self.items.len();
235        ChooseMultiplePrivateOverlayEntries {
236            items: &self.items,
237            indices: rand::seq::index::sample(rng, len, n.min(len)).into_iter(),
238        }
239    }
240
241    /// Clears the set, removing all entries.
242    pub fn clear(&mut self) {
243        self.items.clear();
244    }
245
246    /// Returns `true` if the set contains no elements.
247    pub fn is_empty(&self) -> bool {
248        self.items.is_empty()
249    }
250
251    /// Returns the number of elements in the set, also referred to as its 'length'.
252    pub fn len(&self) -> usize {
253        self.items.len()
254    }
255
256    /// Returns true if the set contains the specified peer id.
257    pub fn contains(&self, peer_id: &PeerId) -> bool {
258        self.items.contains_key(peer_id)
259    }
260
261    /// Returns the peer resolver handle for the specified peer id, if it exists.
262    pub fn get_handle(&self, peer_id: &PeerId) -> Option<&PeerResolverHandle> {
263        self.items.get(peer_id).map(|item| &item.resolver_handle)
264    }
265
266    /// Adds a peer id to the set.
267    ///
268    /// Returns whether the value was newly inserted.
269    pub fn insert(&mut self, peer_id: &PeerId) -> bool {
270        match self.items.entry(*peer_id) {
271            // No entry for the peer_id, insert a new one
272            indexmap::map::Entry::Vacant(entry) => {
273                let handle = self.peer_resolver.as_ref().map_or_else(
274                    || PeerResolverHandle::new_noop(peer_id),
275                    |resolver| resolver.insert(peer_id, true),
276                );
277
278                entry.insert(PrivateOverlayEntryData {
279                    peer_id: *peer_id,
280                    resolver_handle: handle,
281                });
282
283                _ = self
284                    .events_tx
285                    .send(PrivateOverlayEntriesEvent::Added(*peer_id));
286                true
287            }
288            // Entry for the peer_id exists, do nothing
289            indexmap::map::Entry::Occupied(_) => false,
290        }
291    }
292
293    /// Removes a value from the set.
294    ///
295    /// Returns whether the value was present in the set.
296    pub fn remove(&mut self, peer_id: &PeerId) -> bool {
297        let removed = self.items.swap_remove(peer_id).is_some();
298        if removed {
299            _ = self
300                .events_tx
301                .send(PrivateOverlayEntriesEvent::Removed(*peer_id));
302        }
303        removed
304    }
305}
306
307#[derive(Clone)]
308pub struct PrivateOverlayEntryData {
309    pub peer_id: PeerId,
310    pub resolver_handle: PeerResolverHandle,
311}
312
313pub struct PrivateOverlayEntriesWriteGuard<'a> {
314    entries: RwLockWriteGuard<'a, PrivateOverlayEntries>,
315}
316
317impl std::ops::Deref for PrivateOverlayEntriesWriteGuard<'_> {
318    type Target = PrivateOverlayEntries;
319
320    #[inline]
321    fn deref(&self) -> &Self::Target {
322        &self.entries
323    }
324}
325
326impl std::ops::DerefMut for PrivateOverlayEntriesWriteGuard<'_> {
327    #[inline]
328    fn deref_mut(&mut self) -> &mut Self::Target {
329        &mut self.entries
330    }
331}
332
333impl<'a> PrivateOverlayEntriesWriteGuard<'a> {
334    pub fn downgrade(self) -> PrivateOverlayEntriesReadGuard<'a> {
335        let entries = RwLockWriteGuard::downgrade(self.entries);
336        PrivateOverlayEntriesReadGuard { entries }
337    }
338}
339
340pub struct PrivateOverlayEntriesReadGuard<'a> {
341    entries: RwLockReadGuard<'a, PrivateOverlayEntries>,
342}
343
344impl std::ops::Deref for PrivateOverlayEntriesReadGuard<'_> {
345    type Target = PrivateOverlayEntries;
346
347    #[inline]
348    fn deref(&self) -> &Self::Target {
349        &self.entries
350    }
351}
352
353#[derive(Debug, Clone, PartialEq, Eq)]
354pub enum PrivateOverlayEntriesEvent {
355    /// A new entry was inserted.
356    Added(PeerId),
357    /// An existing entry was removed.
358    Removed(PeerId),
359}
360
361pub struct ChooseMultiplePrivateOverlayEntries<'a> {
362    items: &'a OverlayItems,
363    indices: rand::seq::index::IndexVecIntoIter,
364}
365
366impl<'a> Iterator for ChooseMultiplePrivateOverlayEntries<'a> {
367    type Item = &'a PrivateOverlayEntryData;
368
369    fn next(&mut self) -> Option<Self::Item> {
370        self.indices.next().and_then(|i| {
371            let (_, value) = self.items.get_index(i)?;
372            Some(value)
373        })
374    }
375
376    fn size_hint(&self) -> (usize, Option<usize>) {
377        (self.indices.len(), Some(self.indices.len()))
378    }
379}
380
381impl ExactSizeIterator for ChooseMultiplePrivateOverlayEntries<'_> {
382    fn len(&self) -> usize {
383        self.indices.len()
384    }
385}
386
387type OverlayItems = IndexMap<PeerId, PrivateOverlayEntryData, FastHasherState>;
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty entry set without a peer resolver that publishes
    /// its events through `events_tx`.
    fn empty_entries(
        events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
    ) -> PrivateOverlayEntries {
        PrivateOverlayEntries {
            items: Default::default(),
            peer_resolver: None,
            events_tx,
        }
    }

    #[test]
    fn entries_container_is_set() {
        let (events_tx, _) = broadcast::channel(100);
        let mut entries = empty_entries(events_tx);
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);

        let peer_id: PeerId = rand::random();
        assert!(entries.insert(&peer_id));
        assert!(!entries.is_empty());
        assert_eq!(entries.len(), 1);

        // A second insert of the same peer must be rejected.
        assert!(!entries.insert(&peer_id));
        assert_eq!(entries.len(), 1);

        entries.clear();
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);
    }

    #[test]
    fn remove_from_entries_container() {
        let (events_tx, mut events_rx) = broadcast::channel(100);
        let mut entries = empty_entries(events_tx);

        let peer_ids: [PeerId; 10] = std::array::from_fn(|_| rand::random());

        for (expected_len, peer_id) in (1usize..).zip(&peer_ids) {
            assert!(entries.insert(peer_id));
            assert_eq!(entries.len(), expected_len);
            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Added(*peer_id));
        }

        for peer_id in &peer_ids {
            assert!(entries.remove(peer_id));
            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Removed(*peer_id));
            assert!(!entries.items.contains_key(peer_id));
        }
        assert!(entries.is_empty());

        // Removing an unknown peer reports `false` and emits nothing.
        assert!(!entries.remove(&rand::random()));
        assert!(events_rx.try_recv().is_err());
    }
}