1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::Result;
5use indexmap::IndexMap;
6use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7use rand::Rng;
8use tokio::sync::broadcast;
9use tycho_util::futures::BoxFutureOrNoop;
10use tycho_util::{FastHashSet, FastHasherState};
11
12use crate::PrefixedRequest;
13use crate::dht::{PeerResolver, PeerResolverHandle};
14use crate::network::Network;
15use crate::overlay::OverlayId;
16use crate::overlay::metrics::Metrics;
17use crate::proto::overlay::rpc;
18use crate::types::{BoxService, PeerId, Response, Service, ServiceExt, ServiceRequest};
19use crate::util::NetworkExt;
20
21pub struct PrivateOverlayBuilder {
22 overlay_id: OverlayId,
23 entries: FastHashSet<PeerId>,
24 entry_events_channel_size: usize,
25 peer_resolver: Option<PeerResolver>,
26 name: Option<&'static str>,
27}
28
29impl PrivateOverlayBuilder {
30 pub fn with_entries<I>(mut self, allowed_peers: I) -> Self
31 where
32 I: IntoIterator,
33 I::Item: Borrow<PeerId>,
34 {
35 self.entries
36 .extend(allowed_peers.into_iter().map(|p| *p.borrow()));
37 self
38 }
39
40 pub fn with_entry_events_channel_size(mut self, entry_events_channel_size: usize) -> Self {
44 self.entry_events_channel_size = entry_events_channel_size;
45 self
46 }
47
48 pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self {
52 self.peer_resolver = Some(peer_resolver);
53 self
54 }
55
56 pub fn named(mut self, name: &'static str) -> Self {
58 self.name = Some(name);
59 self
60 }
61
62 pub fn build<S>(self, service: S) -> PrivateOverlay
63 where
64 S: Send + Sync + 'static,
65 S: Service<ServiceRequest, QueryResponse = Response>,
66 {
67 let request_prefix = tl_proto::serialize(rpc::Prefix {
68 overlay_id: self.overlay_id.as_bytes(),
69 });
70
71 let mut entries = PrivateOverlayEntries {
72 items: Default::default(),
73 events_tx: broadcast::channel(self.entry_events_channel_size).0,
74 peer_resolver: self.peer_resolver,
75 };
76 for peer_id in self.entries {
77 entries.insert(&peer_id);
78 }
79
80 PrivateOverlay {
81 inner: Arc::new(Inner {
82 overlay_id: self.overlay_id,
83 entries: RwLock::new(entries),
84 service: service.boxed(),
85 request_prefix: request_prefix.into_boxed_slice(),
86 metrics: self
87 .name
88 .map(|label| Metrics::new("tycho_private_overlay", label))
89 .unwrap_or_default(),
90 }),
91 }
92 }
93}
94
95#[derive(Clone)]
96pub struct PrivateOverlay {
97 inner: Arc<Inner>,
98}
99
100impl PrivateOverlay {
101 pub fn builder(overlay_id: OverlayId) -> PrivateOverlayBuilder {
102 PrivateOverlayBuilder {
103 overlay_id,
104 entries: Default::default(),
105 entry_events_channel_size: 100,
106 peer_resolver: None,
107 name: None,
108 }
109 }
110
111 #[inline]
112 pub fn overlay_id(&self) -> &OverlayId {
113 &self.inner.overlay_id
114 }
115
116 pub fn request_from_tl<T>(&self, body: T) -> PrefixedRequest
117 where
118 T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
119 {
120 PrefixedRequest::from_tl(&self.inner.request_prefix, body)
121 }
122
123 pub async fn query(
124 &self,
125 network: &Network,
126 peer_id: &PeerId,
127 request: PrefixedRequest,
128 ) -> Result<Response> {
129 self.inner.metrics.record_rx(request.body_len());
130 network.query(peer_id, request.into()).await
131 }
132
133 pub async fn send(
134 &self,
135 network: &Network,
136 peer_id: &PeerId,
137 request: PrefixedRequest,
138 ) -> Result<()> {
139 self.inner.metrics.record_rx(request.body_len());
140 network.send(peer_id, request.into()).await
141 }
142
143 pub fn write_entries(&self) -> PrivateOverlayEntriesWriteGuard<'_> {
144 PrivateOverlayEntriesWriteGuard {
145 entries: self.inner.entries.write(),
146 }
147 }
148
149 pub fn read_entries(&self) -> PrivateOverlayEntriesReadGuard<'_> {
150 PrivateOverlayEntriesReadGuard {
151 entries: self.inner.entries.read(),
152 }
153 }
154
155 pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop<Option<Response>> {
156 self.inner.metrics.record_rx(req.body.len());
157 if self.inner.entries.read().contains(&req.metadata.peer_id) {
158 BoxFutureOrNoop::future(self.inner.service.on_query(req))
159 } else {
160 BoxFutureOrNoop::Noop
161 }
162 }
163
164 pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> {
165 self.inner.metrics.record_rx(req.body.len());
166 if self.inner.entries.read().contains(&req.metadata.peer_id) {
167 BoxFutureOrNoop::future(self.inner.service.on_message(req))
168 } else {
169 BoxFutureOrNoop::Noop
170 }
171 }
172}
173
174impl std::fmt::Debug for PrivateOverlay {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("PrivateOverlay")
177 .field("overlay_id", &self.inner.overlay_id)
178 .finish()
179 }
180}
181
182struct Inner {
183 overlay_id: OverlayId,
184 entries: RwLock<PrivateOverlayEntries>,
185 service: BoxService<ServiceRequest, Response>,
186 request_prefix: Box<[u8]>,
187 metrics: Metrics,
188}
189
190pub struct PrivateOverlayEntries {
193 items: OverlayItems,
194 events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
195 peer_resolver: Option<PeerResolver>,
196}
197
198impl PrivateOverlayEntries {
199 pub fn subscribe(&self) -> broadcast::Receiver<PrivateOverlayEntriesEvent> {
201 self.events_tx.subscribe()
202 }
203
204 pub fn iter(&self) -> indexmap::map::Values<'_, PeerId, PrivateOverlayEntryData> {
208 self.items.values()
209 }
210
211 pub fn choose<R>(&self, rng: &mut R) -> Option<&PrivateOverlayEntryData>
213 where
214 R: Rng + ?Sized,
215 {
216 let index = rng.random_range(0..self.items.len());
217 let (_, value) = self.items.get_index(index)?;
218 Some(value)
219 }
220
221 pub fn choose_multiple<R>(
224 &self,
225 rng: &mut R,
226 n: usize,
227 ) -> ChooseMultiplePrivateOverlayEntries<'_>
228 where
229 R: Rng + ?Sized,
230 {
231 let len = self.items.len();
232 ChooseMultiplePrivateOverlayEntries {
233 items: &self.items,
234 indices: rand::seq::index::sample(rng, len, n.min(len)).into_iter(),
235 }
236 }
237
238 pub fn clear(&mut self) {
240 self.items.clear();
241 }
242
243 pub fn is_empty(&self) -> bool {
245 self.items.is_empty()
246 }
247
248 pub fn len(&self) -> usize {
250 self.items.len()
251 }
252
253 pub fn contains(&self, peer_id: &PeerId) -> bool {
255 self.items.contains_key(peer_id)
256 }
257
258 pub fn get_handle(&self, peer_id: &PeerId) -> Option<&PeerResolverHandle> {
260 self.items.get(peer_id).map(|item| &item.resolver_handle)
261 }
262
263 pub fn insert(&mut self, peer_id: &PeerId) -> bool {
267 match self.items.entry(*peer_id) {
268 indexmap::map::Entry::Vacant(entry) => {
270 let handle = self.peer_resolver.as_ref().map_or_else(
271 || PeerResolverHandle::new_noop(peer_id),
272 |resolver| resolver.insert(peer_id, true),
273 );
274
275 entry.insert(PrivateOverlayEntryData {
276 peer_id: *peer_id,
277 resolver_handle: handle,
278 });
279
280 _ = self
281 .events_tx
282 .send(PrivateOverlayEntriesEvent::Added(*peer_id));
283 true
284 }
285 indexmap::map::Entry::Occupied(_) => false,
287 }
288 }
289
290 pub fn remove(&mut self, peer_id: &PeerId) -> bool {
294 let removed = self.items.swap_remove(peer_id).is_some();
295 if removed {
296 _ = self
297 .events_tx
298 .send(PrivateOverlayEntriesEvent::Removed(*peer_id));
299 }
300 removed
301 }
302}
303
304#[derive(Clone)]
305pub struct PrivateOverlayEntryData {
306 pub peer_id: PeerId,
307 pub resolver_handle: PeerResolverHandle,
308}
309
310pub struct PrivateOverlayEntriesWriteGuard<'a> {
311 entries: RwLockWriteGuard<'a, PrivateOverlayEntries>,
312}
313
314impl std::ops::Deref for PrivateOverlayEntriesWriteGuard<'_> {
315 type Target = PrivateOverlayEntries;
316
317 #[inline]
318 fn deref(&self) -> &Self::Target {
319 &self.entries
320 }
321}
322
323impl std::ops::DerefMut for PrivateOverlayEntriesWriteGuard<'_> {
324 #[inline]
325 fn deref_mut(&mut self) -> &mut Self::Target {
326 &mut self.entries
327 }
328}
329
330impl<'a> PrivateOverlayEntriesWriteGuard<'a> {
331 pub fn downgrade(self) -> PrivateOverlayEntriesReadGuard<'a> {
332 let entries = RwLockWriteGuard::downgrade(self.entries);
333 PrivateOverlayEntriesReadGuard { entries }
334 }
335}
336
337pub struct PrivateOverlayEntriesReadGuard<'a> {
338 entries: RwLockReadGuard<'a, PrivateOverlayEntries>,
339}
340
341impl std::ops::Deref for PrivateOverlayEntriesReadGuard<'_> {
342 type Target = PrivateOverlayEntries;
343
344 #[inline]
345 fn deref(&self) -> &Self::Target {
346 &self.entries
347 }
348}
349
350#[derive(Debug, Clone, PartialEq, Eq)]
351pub enum PrivateOverlayEntriesEvent {
352 Added(PeerId),
354 Removed(PeerId),
356}
357
358pub struct ChooseMultiplePrivateOverlayEntries<'a> {
359 items: &'a OverlayItems,
360 indices: rand::seq::index::IndexVecIntoIter,
361}
362
363impl<'a> Iterator for ChooseMultiplePrivateOverlayEntries<'a> {
364 type Item = &'a PrivateOverlayEntryData;
365
366 fn next(&mut self) -> Option<Self::Item> {
367 self.indices.next().and_then(|i| {
368 let (_, value) = self.items.get_index(i)?;
369 Some(value)
370 })
371 }
372
373 fn size_hint(&self) -> (usize, Option<usize>) {
374 (self.indices.len(), Some(self.indices.len()))
375 }
376}
377
378impl ExactSizeIterator for ChooseMultiplePrivateOverlayEntries<'_> {
379 fn len(&self) -> usize {
380 self.indices.len()
381 }
382}
383
384type OverlayItems = IndexMap<PeerId, PrivateOverlayEntryData, FastHasherState>;
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty entries container without a peer resolver, together
    /// with a receiver for its events.
    fn empty_entries() -> (
        PrivateOverlayEntries,
        broadcast::Receiver<PrivateOverlayEntriesEvent>,
    ) {
        let (events_tx, events_rx) = broadcast::channel(100);
        let entries = PrivateOverlayEntries {
            items: Default::default(),
            peer_resolver: None,
            events_tx,
        };
        (entries, events_rx)
    }

    #[test]
    fn entries_container_is_set() {
        // The receiver is dropped right away, matching a channel with no subscribers.
        let (mut entries, _) = empty_entries();
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);

        let peer_id: PeerId = rand::random();

        // The first insertion adds the peer...
        assert!(entries.insert(&peer_id));
        assert!(!entries.is_empty());
        assert_eq!(entries.len(), 1);

        // ...and a repeated insertion is rejected without changing the size.
        assert!(!entries.insert(&peer_id));
        assert_eq!(entries.len(), 1);

        entries.clear();
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);
    }

    #[test]
    fn remove_from_entries_container() {
        let (mut entries, mut events_rx) = empty_entries();

        let peer_ids: [PeerId; 10] = std::array::from_fn(|_| rand::random());

        for (i, peer_id) in peer_ids.iter().enumerate() {
            assert!(entries.insert(peer_id));
            assert_eq!(entries.len(), i + 1);

            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Added(*peer_id));
        }

        for peer_id in &peer_ids {
            assert!(entries.remove(peer_id));

            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Removed(*peer_id));

            assert!(!entries.items.contains_key(peer_id));
        }

        assert!(entries.is_empty());

        // Removing an unknown peer is a no-op and emits nothing.
        let unknown: PeerId = rand::random();
        assert!(!entries.remove(&unknown));
        assert!(events_rx.try_recv().is_err());
    }
}