1use std::borrow::Borrow;
2use std::sync::Arc;
3
4use anyhow::Result;
5use bytes::{Bytes, BytesMut};
6use indexmap::IndexMap;
7use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
8use rand::Rng;
9use tokio::sync::broadcast;
10use tycho_util::futures::BoxFutureOrNoop;
11use tycho_util::{FastHashSet, FastHasherState};
12
13use crate::dht::{PeerResolver, PeerResolverHandle};
14use crate::network::Network;
15use crate::overlay::OverlayId;
16use crate::overlay::metrics::Metrics;
17use crate::proto::overlay::rpc;
18use crate::types::{BoxService, PeerId, Request, Response, Service, ServiceExt, ServiceRequest};
19use crate::util::NetworkExt;
20
21pub struct PrivateOverlayBuilder {
22 overlay_id: OverlayId,
23 entries: FastHashSet<PeerId>,
24 entry_events_channel_size: usize,
25 peer_resolver: Option<PeerResolver>,
26 name: Option<&'static str>,
27}
28
29impl PrivateOverlayBuilder {
30 pub fn with_entries<I>(mut self, allowed_peers: I) -> Self
31 where
32 I: IntoIterator,
33 I::Item: Borrow<PeerId>,
34 {
35 self.entries
36 .extend(allowed_peers.into_iter().map(|p| *p.borrow()));
37 self
38 }
39
40 pub fn with_entry_events_channel_size(mut self, entry_events_channel_size: usize) -> Self {
44 self.entry_events_channel_size = entry_events_channel_size;
45 self
46 }
47
48 pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self {
52 self.peer_resolver = Some(peer_resolver);
53 self
54 }
55
56 pub fn named(mut self, name: &'static str) -> Self {
58 self.name = Some(name);
59 self
60 }
61
62 pub fn build<S>(self, service: S) -> PrivateOverlay
63 where
64 S: Send + Sync + 'static,
65 S: Service<ServiceRequest, QueryResponse = Response>,
66 {
67 let request_prefix = tl_proto::serialize(rpc::Prefix {
68 overlay_id: self.overlay_id.as_bytes(),
69 });
70
71 let mut entries = PrivateOverlayEntries {
72 items: Default::default(),
73 events_tx: broadcast::channel(self.entry_events_channel_size).0,
74 peer_resolver: self.peer_resolver,
75 };
76 for peer_id in self.entries {
77 entries.insert(&peer_id);
78 }
79
80 PrivateOverlay {
81 inner: Arc::new(Inner {
82 overlay_id: self.overlay_id,
83 entries: RwLock::new(entries),
84 service: service.boxed(),
85 request_prefix: request_prefix.into_boxed_slice(),
86 metrics: self
87 .name
88 .map(|label| Metrics::new("tycho_private_overlay", label))
89 .unwrap_or_default(),
90 }),
91 }
92 }
93}
94
95#[derive(Clone)]
96pub struct PrivateOverlay {
97 inner: Arc<Inner>,
98}
99
100impl PrivateOverlay {
101 pub fn builder(overlay_id: OverlayId) -> PrivateOverlayBuilder {
102 PrivateOverlayBuilder {
103 overlay_id,
104 entries: Default::default(),
105 entry_events_channel_size: 100,
106 peer_resolver: None,
107 name: None,
108 }
109 }
110
111 #[inline]
112 pub fn overlay_id(&self) -> &OverlayId {
113 &self.inner.overlay_id
114 }
115
116 pub async fn query(
117 &self,
118 network: &Network,
119 peer_id: &PeerId,
120 mut request: Request,
121 ) -> Result<Response> {
122 self.inner.metrics.record_rx(request.body.len());
123 self.prepend_prefix_to_body(&mut request.body);
124 network.query(peer_id, request).await
125 }
126
127 pub async fn send(
128 &self,
129 network: &Network,
130 peer_id: &PeerId,
131 mut request: Request,
132 ) -> Result<()> {
133 self.inner.metrics.record_rx(request.body.len());
134 self.prepend_prefix_to_body(&mut request.body);
135 network.send(peer_id, request).await
136 }
137
138 pub fn write_entries(&self) -> PrivateOverlayEntriesWriteGuard<'_> {
139 PrivateOverlayEntriesWriteGuard {
140 entries: self.inner.entries.write(),
141 }
142 }
143
144 pub fn read_entries(&self) -> PrivateOverlayEntriesReadGuard<'_> {
145 PrivateOverlayEntriesReadGuard {
146 entries: self.inner.entries.read(),
147 }
148 }
149
150 pub(crate) fn handle_query(&self, req: ServiceRequest) -> BoxFutureOrNoop<Option<Response>> {
151 self.inner.metrics.record_rx(req.body.len());
152 if self.inner.entries.read().contains(&req.metadata.peer_id) {
153 BoxFutureOrNoop::future(self.inner.service.on_query(req))
154 } else {
155 BoxFutureOrNoop::Noop
156 }
157 }
158
159 pub(crate) fn handle_message(&self, req: ServiceRequest) -> BoxFutureOrNoop<()> {
160 self.inner.metrics.record_rx(req.body.len());
161 if self.inner.entries.read().contains(&req.metadata.peer_id) {
162 BoxFutureOrNoop::future(self.inner.service.on_message(req))
163 } else {
164 BoxFutureOrNoop::Noop
165 }
166 }
167
168 fn prepend_prefix_to_body(&self, body: &mut Bytes) {
169 let mut res = BytesMut::with_capacity(self.inner.request_prefix.len() + body.len());
171 res.extend_from_slice(&self.inner.request_prefix);
172 res.extend_from_slice(body);
173 *body = res.freeze();
174 }
175}
176
177impl std::fmt::Debug for PrivateOverlay {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("PrivateOverlay")
180 .field("overlay_id", &self.inner.overlay_id)
181 .finish()
182 }
183}
184
185struct Inner {
186 overlay_id: OverlayId,
187 entries: RwLock<PrivateOverlayEntries>,
188 service: BoxService<ServiceRequest, Response>,
189 request_prefix: Box<[u8]>,
190 metrics: Metrics,
191}
192
193pub struct PrivateOverlayEntries {
196 items: OverlayItems,
197 events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
198 peer_resolver: Option<PeerResolver>,
199}
200
201impl PrivateOverlayEntries {
202 pub fn subscribe(&self) -> broadcast::Receiver<PrivateOverlayEntriesEvent> {
204 self.events_tx.subscribe()
205 }
206
207 pub fn iter(&self) -> indexmap::map::Values<'_, PeerId, PrivateOverlayEntryData> {
211 self.items.values()
212 }
213
214 pub fn choose<R>(&self, rng: &mut R) -> Option<&PrivateOverlayEntryData>
216 where
217 R: Rng + ?Sized,
218 {
219 let index = rng.random_range(0..self.items.len());
220 let (_, value) = self.items.get_index(index)?;
221 Some(value)
222 }
223
224 pub fn choose_multiple<R>(
227 &self,
228 rng: &mut R,
229 n: usize,
230 ) -> ChooseMultiplePrivateOverlayEntries<'_>
231 where
232 R: Rng + ?Sized,
233 {
234 let len = self.items.len();
235 ChooseMultiplePrivateOverlayEntries {
236 items: &self.items,
237 indices: rand::seq::index::sample(rng, len, n.min(len)).into_iter(),
238 }
239 }
240
241 pub fn clear(&mut self) {
243 self.items.clear();
244 }
245
246 pub fn is_empty(&self) -> bool {
248 self.items.is_empty()
249 }
250
251 pub fn len(&self) -> usize {
253 self.items.len()
254 }
255
256 pub fn contains(&self, peer_id: &PeerId) -> bool {
258 self.items.contains_key(peer_id)
259 }
260
261 pub fn get_handle(&self, peer_id: &PeerId) -> Option<&PeerResolverHandle> {
263 self.items.get(peer_id).map(|item| &item.resolver_handle)
264 }
265
266 pub fn insert(&mut self, peer_id: &PeerId) -> bool {
270 match self.items.entry(*peer_id) {
271 indexmap::map::Entry::Vacant(entry) => {
273 let handle = self.peer_resolver.as_ref().map_or_else(
274 || PeerResolverHandle::new_noop(peer_id),
275 |resolver| resolver.insert(peer_id, true),
276 );
277
278 entry.insert(PrivateOverlayEntryData {
279 peer_id: *peer_id,
280 resolver_handle: handle,
281 });
282
283 _ = self
284 .events_tx
285 .send(PrivateOverlayEntriesEvent::Added(*peer_id));
286 true
287 }
288 indexmap::map::Entry::Occupied(_) => false,
290 }
291 }
292
293 pub fn remove(&mut self, peer_id: &PeerId) -> bool {
297 let removed = self.items.swap_remove(peer_id).is_some();
298 if removed {
299 _ = self
300 .events_tx
301 .send(PrivateOverlayEntriesEvent::Removed(*peer_id));
302 }
303 removed
304 }
305}
306
307#[derive(Clone)]
308pub struct PrivateOverlayEntryData {
309 pub peer_id: PeerId,
310 pub resolver_handle: PeerResolverHandle,
311}
312
313pub struct PrivateOverlayEntriesWriteGuard<'a> {
314 entries: RwLockWriteGuard<'a, PrivateOverlayEntries>,
315}
316
317impl std::ops::Deref for PrivateOverlayEntriesWriteGuard<'_> {
318 type Target = PrivateOverlayEntries;
319
320 #[inline]
321 fn deref(&self) -> &Self::Target {
322 &self.entries
323 }
324}
325
326impl std::ops::DerefMut for PrivateOverlayEntriesWriteGuard<'_> {
327 #[inline]
328 fn deref_mut(&mut self) -> &mut Self::Target {
329 &mut self.entries
330 }
331}
332
333impl<'a> PrivateOverlayEntriesWriteGuard<'a> {
334 pub fn downgrade(self) -> PrivateOverlayEntriesReadGuard<'a> {
335 let entries = RwLockWriteGuard::downgrade(self.entries);
336 PrivateOverlayEntriesReadGuard { entries }
337 }
338}
339
340pub struct PrivateOverlayEntriesReadGuard<'a> {
341 entries: RwLockReadGuard<'a, PrivateOverlayEntries>,
342}
343
344impl std::ops::Deref for PrivateOverlayEntriesReadGuard<'_> {
345 type Target = PrivateOverlayEntries;
346
347 #[inline]
348 fn deref(&self) -> &Self::Target {
349 &self.entries
350 }
351}
352
353#[derive(Debug, Clone, PartialEq, Eq)]
354pub enum PrivateOverlayEntriesEvent {
355 Added(PeerId),
357 Removed(PeerId),
359}
360
361pub struct ChooseMultiplePrivateOverlayEntries<'a> {
362 items: &'a OverlayItems,
363 indices: rand::seq::index::IndexVecIntoIter,
364}
365
366impl<'a> Iterator for ChooseMultiplePrivateOverlayEntries<'a> {
367 type Item = &'a PrivateOverlayEntryData;
368
369 fn next(&mut self) -> Option<Self::Item> {
370 self.indices.next().and_then(|i| {
371 let (_, value) = self.items.get_index(i)?;
372 Some(value)
373 })
374 }
375
376 fn size_hint(&self) -> (usize, Option<usize>) {
377 (self.indices.len(), Some(self.indices.len()))
378 }
379}
380
381impl ExactSizeIterator for ChooseMultiplePrivateOverlayEntries<'_> {
382 fn len(&self) -> usize {
383 self.indices.len()
384 }
385}
386
387type OverlayItems = IndexMap<PeerId, PrivateOverlayEntryData, FastHasherState>;
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty entries container without a peer resolver.
    fn make_entries(
        events_tx: broadcast::Sender<PrivateOverlayEntriesEvent>,
    ) -> PrivateOverlayEntries {
        PrivateOverlayEntries {
            items: Default::default(),
            peer_resolver: None,
            events_tx,
        }
    }

    #[test]
    fn entries_container_is_set() {
        let (events_tx, _) = broadcast::channel(100);
        let mut entries = make_entries(events_tx);
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);

        let peer_id: PeerId = rand::random();

        // The first insertion adds the peer...
        assert!(entries.insert(&peer_id));
        assert!(!entries.is_empty());
        assert_eq!(entries.len(), 1);

        // ...while inserting it again is a no-op.
        assert!(!entries.insert(&peer_id));
        assert_eq!(entries.len(), 1);

        entries.clear();
        assert!(entries.is_empty());
        assert_eq!(entries.len(), 0);
    }

    #[test]
    fn remove_from_entries_container() {
        let (events_tx, mut events_rx) = broadcast::channel(100);
        let mut entries = make_entries(events_tx);

        let peer_ids: [PeerId; 10] = std::array::from_fn(|_| rand::random());

        for (expected_len, peer_id) in (1..).zip(&peer_ids) {
            assert!(entries.insert(peer_id));
            assert_eq!(entries.len(), expected_len);

            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Added(*peer_id));
        }

        for peer_id in peer_ids.iter() {
            assert!(entries.remove(peer_id));

            let event = events_rx.try_recv().unwrap();
            assert_eq!(event, PrivateOverlayEntriesEvent::Removed(*peer_id));

            assert!(!entries.items.contains_key(peer_id));
        }

        assert!(entries.is_empty());

        // Removing an unknown peer reports failure and emits no event.
        let unknown: PeerId = rand::random();
        assert!(!entries.remove(&unknown));
        assert!(events_rx.try_recv().is_err());
    }
}