Skip to main content

tycho_network/overlay/
private_overlay.rs

1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::Result;
5use indexmap::IndexMap;
6use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7use rand::Rng;
8use tokio::sync::broadcast;
9use tycho_util::futures::BoxFutureOrNoop;
10use tycho_util::{FastHashSet, FastHasherState};
11
12use crate::PrefixedRequest;
13use crate::dht::{PeerResolver, PeerResolverHandle};
14use crate::network::Network;
15use crate::overlay::OverlayId;
16use crate::overlay::metrics::Metrics;
17use crate::proto::overlay::rpc;
18use crate::types::{BoxService, PeerId, Response, Service, ServiceExt, ServiceRequest};
19use crate::util::NetworkExt;
20
/// Builder for a [`PrivateOverlay`].
///
/// Obtained via [`PrivateOverlay::builder`] and finished with
/// [`PrivateOverlayBuilder::build`].
pub struct PrivateOverlayBuilder {
    /// Id of the overlay; `build` also serializes it into the request prefix.
    overlay_id: OverlayId,
    /// Peers inserted into the entry set when the overlay is built.
    entries: FastHashSet<PeerId>,
    /// Capacity of the broadcast channel used for entry set events.
    entry_events_channel_size: usize,
    /// Optional resolver used to create handles for inserted peers.
    peer_resolver: Option<PeerResolver>,
    /// Optional label under which overlay metrics are registered.
    name: Option<&'static str>,
}
28
29impl PrivateOverlayBuilder {
30    pub fn with_entries<I>(mut self, allowed_peers: I) -> Self
31    where
32        I: IntoIterator,
33        I::Item: Borrow<PeerId>,
34    {
35        self.entries
36            .extend(allowed_peers.into_iter().map(|p| *p.borrow()));
37        self
38    }
39
40    /// The capacity of entries set events.
41    ///
42    /// Default: 100.
43    pub fn with_entry_events_channel_size(mut self, entry_events_channel_size: usize) -> Self {
44        self.entry_events_channel_size = entry_events_channel_size;
45        self
46    }
47
48    /// Whether to resolve peers with the provided resolver.
49    ///
50    /// Does not resolve peers by default.
51    pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self {
52        self.peer_resolver = Some(peer_resolver);
53        self
54    }
55
56    /// Name of the overlay used in metrics.
57    pub fn named(mut self, name: &'static str) -> Self {
58        self.name = Some(name);
59        self
60    }
61
62    pub fn build<S>(self, service: S) -> PrivateOverlay
63    where
64        S: Send + Sync + 'static,
65        S: Service<ServiceRequest, QueryResponse = Response>,
66    {
67        let request_prefix = tl_proto::serialize(rpc::Prefix {
68            overlay_id: self.overlay_id.as_bytes(),
69        });
70
71        let mut entries = PrivateOverlayEntries {
72            items: Default::default(),
73            events_tx: broadcast::channel(self.entry_events_channel_size).0,
74            peer_resolver: self.peer_resolver,
75        };
76        for peer_id in self.entries {
77            entries.insert(&peer_id);
78        }
79
80        PrivateOverlay {
81            inner: Arc::new(Inner {
82                overlay_id: self.overlay_id,
83                entries: RwLock::new(entries),
84                service: service.boxed(),
85                request_prefix: request_prefix.into_boxed_slice(),
86                metrics: self
87                    .name
88                    .map(|label| Metrics::new("tycho_private_overlay", label))
89                    .unwrap_or_default(),
90            }),
91        }
92    }
93}
94
/// An overlay whose incoming traffic is restricted to an explicit set of peers.
///
/// Incoming queries and messages reach the service only when the sender is in
/// the entry set. Cloning is cheap: all clones share the same `Arc`-ed state.
#[derive(Clone)]
pub struct PrivateOverlay {
    inner: Arc<Inner>,
}
99
100impl PrivateOverlay {
101    pub fn builder(overlay_id: OverlayId) -> PrivateOverlayBuilder {
102        PrivateOverlayBuilder {
103            overlay_id,
104            entries: Default::default(),
105            entry_events_channel_size: 100,
106            peer_resolver: None,
107            name: None,
108        }
109    }
110
111    #[inline]
112    pub fn overlay_id(&self) -> &OverlayId {
113        &self.inner.overlay_id
114    }
115
116    pub fn request_from_tl<T>(&self, body: T) -> PrefixedRequest
117    where
118        T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
119    {
120        PrefixedRequest::from_tl(&self.inner.request_prefix, body)
121    }
122
123    pub async fn query(
124        &self,
125        network: &Network,
126        peer_id: &PeerId,
127        request: PrefixedRequest,
128    ) -> Result<Response> {
129        self.inner.metrics.record_rx(request.body_len());
130        network.query(peer_id, request.into()).await
131    }
132
133    pub async fn send(
134        &self,
135        network: &Network,
136        peer_id: &PeerId,
137        request: PrefixedRequest,
138    ) -> Result<()> {
139        self.inner.metrics.record_rx(request.body_len());
140        network.send(peer_id, request.into()).await
141    }
142
143    pub fn write_entries(&self) -> PrivateOverlayEntriesWriteGuard<'_> {
144        PrivateOverlayEntriesWriteGuard {
145            entries: self.inner.entries.write(),
146        }
147    }
148
149    pub fn read_entries(&self) -> PrivateOverlayEntriesReadGuard<'_> {
150        PrivateOverlayEntriesReadGuard {
151            entries: self.inner.entries.read(),
152        }
153    }
154
155    pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop<Option<Response>> {
156        self.inner.metrics.record_rx(req.body.len());
157        if self.inner.entries.read().contains(&req.metadata.peer_id) {
158            BoxFutureOrNoop::future(self.inner.service.on_query(req))
159        } else {
160            BoxFutureOrNoop::Noop
161        }
162    }
163
164    pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> {
165        self.inner.metrics.record_rx(req.body.len());
166        if self.inner.entries.read().contains(&req.metadata.peer_id) {
167            BoxFutureOrNoop::future(self.inner.service.on_message(req))
168        } else {
169            BoxFutureOrNoop::Noop
170        }
171    }
172}
173
174impl std::fmt::Debug for PrivateOverlay {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("PrivateOverlay")
177            .field("overlay_id", &self.inner.overlay_id)
178            .finish()
179    }
180}
181
/// Shared state behind every clone of a [`PrivateOverlay`].
struct Inner {
    /// Id of the overlay.
    overlay_id: OverlayId,
    /// Peers whose incoming queries/messages are forwarded to `service`.
    entries: RwLock<PrivateOverlayEntries>,
    /// Handler for incoming traffic from allowed peers.
    service: BoxService<ServiceRequest, Response>,
    /// TL-serialized `rpc::Prefix` with the overlay id, used by
    /// `PrefixedRequest::from_tl` when building requests.
    request_prefix: Box<[u8]>,
    /// Traffic metrics (labelled when the builder was given a name).
    metrics: Metrics,
}
189
// NOTE: `Default` is intentionally not derived, so that this type cannot be
// constructed outside the crate.
/// The set of peers allowed to talk to a private overlay.
///
/// `insert` and `remove` notify subscribers with
/// [`PrivateOverlayEntriesEvent`]s.
pub struct PrivateOverlayEntries {
    /// Entries keyed by peer id.
    items: OverlayItems,
    /// Sender side of the entry set events channel.
    events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
    /// Optional resolver used to create handles for newly inserted peers.
    peer_resolver: Option<PeerResolver>,
}
197
198impl PrivateOverlayEntries {
199    /// Subscribes to the set updates.
200    pub fn subscribe(&self) -> broadcast::Receiver<PrivateOverlayEntriesEvent> {
201        self.events_tx.subscribe()
202    }
203
204    /// Returns an iterator over the entry ids.
205    ///
206    /// The order is not random, but is not defined.
207    pub fn iter(&self) -> indexmap::map::Values<'_, PeerId, PrivateOverlayEntryData> {
208        self.items.values()
209    }
210
211    /// Returns one random peer, or `None` if set is empty.
212    pub fn choose<R>(&self, rng: &mut R) -> Option<&PrivateOverlayEntryData>
213    where
214        R: Rng + ?Sized,
215    {
216        let index = rng.random_range(0..self.items.len());
217        let (_, value) = self.items.get_index(index)?;
218        Some(value)
219    }
220
221    /// Chooses `n` entries from the set, without repetition,
222    /// and in random order.
223    pub fn choose_multiple<R>(
224        &self,
225        rng: &mut R,
226        n: usize,
227    ) -> ChooseMultiplePrivateOverlayEntries<'_>
228    where
229        R: Rng + ?Sized,
230    {
231        let len = self.items.len();
232        ChooseMultiplePrivateOverlayEntries {
233            items: &self.items,
234            indices: rand::seq::index::sample(rng, len, n.min(len)).into_iter(),
235        }
236    }
237
238    /// Clears the set, removing all entries.
239    pub fn clear(&mut self) {
240        self.items.clear();
241    }
242
243    /// Returns `true` if the set contains no elements.
244    pub fn is_empty(&self) -> bool {
245        self.items.is_empty()
246    }
247
248    /// Returns the number of elements in the set, also referred to as its 'length'.
249    pub fn len(&self) -> usize {
250        self.items.len()
251    }
252
253    /// Returns true if the set contains the specified peer id.
254    pub fn contains(&self, peer_id: &PeerId) -> bool {
255        self.items.contains_key(peer_id)
256    }
257
258    /// Returns the peer resolver handle for the specified peer id, if it exists.
259    pub fn get_handle(&self, peer_id: &PeerId) -> Option<&PeerResolverHandle> {
260        self.items.get(peer_id).map(|item| &item.resolver_handle)
261    }
262
263    /// Adds a peer id to the set.
264    ///
265    /// Returns whether the value was newly inserted.
266    pub fn insert(&mut self, peer_id: &PeerId) -> bool {
267        match self.items.entry(*peer_id) {
268            // No entry for the peer_id, insert a new one
269            indexmap::map::Entry::Vacant(entry) => {
270                let handle = self.peer_resolver.as_ref().map_or_else(
271                    || PeerResolverHandle::new_noop(peer_id),
272                    |resolver| resolver.insert(peer_id, true),
273                );
274
275                entry.insert(PrivateOverlayEntryData {
276                    peer_id: *peer_id,
277                    resolver_handle: handle,
278                });
279
280                _ = self
281                    .events_tx
282                    .send(PrivateOverlayEntriesEvent::Added(*peer_id));
283                true
284            }
285            // Entry for the peer_id exists, do nothing
286            indexmap::map::Entry::Occupied(_) => false,
287        }
288    }
289
290    /// Removes a value from the set.
291    ///
292    /// Returns whether the value was present in the set.
293    pub fn remove(&mut self, peer_id: &PeerId) -> bool {
294        let removed = self.items.swap_remove(peer_id).is_some();
295        if removed {
296            _ = self
297                .events_tx
298                .send(PrivateOverlayEntriesEvent::Removed(*peer_id));
299        }
300        removed
301    }
302}
303
/// A single entry of a private overlay's peer set.
#[derive(Clone)]
pub struct PrivateOverlayEntryData {
    /// Id of the peer.
    pub peer_id: PeerId,
    /// Resolver handle for the peer. It is a no-op handle when the overlay was
    /// built without a peer resolver.
    pub resolver_handle: PeerResolverHandle,
}
309
/// Exclusive (write) access to the entries of a [`PrivateOverlay`].
///
/// The write lock is held until this guard is dropped or downgraded.
pub struct PrivateOverlayEntriesWriteGuard<'a> {
    entries: RwLockWriteGuard<'a, PrivateOverlayEntries>,
}
313
314impl std::ops::Deref for PrivateOverlayEntriesWriteGuard<'_> {
315    type Target = PrivateOverlayEntries;
316
317    #[inline]
318    fn deref(&self) -> &Self::Target {
319        &self.entries
320    }
321}
322
323impl std::ops::DerefMut for PrivateOverlayEntriesWriteGuard<'_> {
324    #[inline]
325    fn deref_mut(&mut self) -> &mut Self::Target {
326        &mut self.entries
327    }
328}
329
330impl<'a> PrivateOverlayEntriesWriteGuard<'a> {
331    pub fn downgrade(self) -> PrivateOverlayEntriesReadGuard<'a> {
332        let entries = RwLockWriteGuard::downgrade(self.entries);
333        PrivateOverlayEntriesReadGuard { entries }
334    }
335}
336
/// Shared (read) access to the entries of a [`PrivateOverlay`].
///
/// The read lock is held until this guard is dropped.
pub struct PrivateOverlayEntriesReadGuard<'a> {
    entries: RwLockReadGuard<'a, PrivateOverlayEntries>,
}
340
341impl std::ops::Deref for PrivateOverlayEntriesReadGuard<'_> {
342    type Target = PrivateOverlayEntries;
343
344    #[inline]
345    fn deref(&self) -> &Self::Target {
346        &self.entries
347    }
348}
349
/// Change notification for [`PrivateOverlayEntries`], delivered to receivers
/// obtained via [`PrivateOverlayEntries::subscribe`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrivateOverlayEntriesEvent {
    /// A new entry was inserted.
    Added(PeerId),
    /// An existing entry was removed.
    Removed(PeerId),
}
357
/// Iterator over randomly chosen entries, returned by
/// [`PrivateOverlayEntries::choose_multiple`].
pub struct ChooseMultiplePrivateOverlayEntries<'a> {
    /// The entries being sampled from.
    items: &'a OverlayItems,
    /// Remaining sampled positions into `items`.
    indices: rand::seq::index::IndexVecIntoIter,
}
362
363impl<'a> Iterator for ChooseMultiplePrivateOverlayEntries<'a> {
364    type Item = &'a PrivateOverlayEntryData;
365
366    fn next(&mut self) -> Option<Self::Item> {
367        self.indices.next().and_then(|i| {
368            let (_, value) = self.items.get_index(i)?;
369            Some(value)
370        })
371    }
372
373    fn size_hint(&self) -> (usize, Option<usize>) {
374        (self.indices.len(), Some(self.indices.len()))
375    }
376}
377
378impl ExactSizeIterator for ChooseMultiplePrivateOverlayEntries<'_> {
379    fn len(&self) -> usize {
380        self.indices.len()
381    }
382}
383
/// Entries keyed by peer id. An `IndexMap` is used because random selection
/// (`choose`, `choose_multiple`) looks entries up by position.
type OverlayItems = IndexMap<PeerId, PrivateOverlayEntryData, FastHasherState>;
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty entries set without a peer resolver that reports its
    /// events through `events_tx`.
    fn empty_entries(
        events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
    ) -> PrivateOverlayEntries {
        PrivateOverlayEntries {
            items: Default::default(),
            events_tx,
            peer_resolver: None,
        }
    }

    #[test]
    fn entries_container_is_set() {
        // The receiver is dropped right away; event delivery is not checked here.
        let (events_tx, _) = broadcast::channel(100);
        let mut entries = empty_entries(events_tx);
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);

        let peer_id: PeerId = rand::random();

        // The first insertion adds the peer...
        assert!(entries.insert(&peer_id));
        assert!(!entries.is_empty());
        assert_eq!(entries.len(), 1);

        // ...while inserting it again is a no-op.
        assert!(!entries.insert(&peer_id));
        assert_eq!(entries.len(), 1);

        entries.clear();
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);
    }

    #[test]
    fn remove_from_entries_container() {
        let (events_tx, mut events_rx) = broadcast::channel(100);
        let mut entries = empty_entries(events_tx);

        let peer_ids: Vec<PeerId> = (0..10).map(|_| rand::random()).collect();

        for (expected_len, peer_id) in (1usize..).zip(&peer_ids) {
            assert!(entries.insert(peer_id));
            assert_eq!(entries.len(), expected_len);
            assert_eq!(
                events_rx.try_recv().unwrap(),
                PrivateOverlayEntriesEvent::Added(*peer_id)
            );
        }

        for peer_id in &peer_ids {
            assert!(entries.remove(peer_id));
            assert_eq!(
                events_rx.try_recv().unwrap(),
                PrivateOverlayEntriesEvent::Removed(*peer_id)
            );
            assert!(!entries.contains(peer_id));
        }
        assert!(entries.is_empty());

        // Removing an unknown peer reports `false` and emits no event.
        let unknown: PeerId = rand::random();
        assert!(!entries.remove(&unknown));
        assert!(events_rx.try_recv().is_err());
    }
}
449}