1use std::sync::Arc;
2
3use bytes::Buf;
4use tl_proto::{TlError, TlRead};
5use tokio::sync::Notify;
6use tycho_util::futures::BoxFutureOrNoop;
7use tycho_util::time::now_sec;
8use tycho_util::{FastDashMap, FastHashMap, FastHashSet};
9
10pub use self::config::OverlayConfig;
11use self::entries_merger::PublicOverlayEntriesMerger;
12pub use self::overlay_id::OverlayId;
13pub use self::private_overlay::{
14 ChooseMultiplePrivateOverlayEntries, PrivateOverlay, PrivateOverlayBuilder,
15 PrivateOverlayEntries, PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard,
16 PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData,
17};
18pub use self::public_overlay::{
19 ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries,
20 PublicOverlayEntriesReadGuard, PublicOverlayEntryData, UnknownPeersQueue,
21};
22use crate::dht::DhtService;
23use crate::network::Network;
24use crate::proto::overlay::{PublicEntriesResponse, PublicEntry, PublicEntryResponse, rpc};
25use crate::types::{PeerId, Response, Service, ServiceRequest};
26use crate::util::Routable;
27
28mod background_tasks;
29mod config;
30mod entries_merger;
31mod metrics;
32mod overlay_id;
33mod private_overlay;
34mod public_overlay;
35mod tasks_stream;
36
/// Background work of an [`OverlayService`], returned alongside the service
/// by [`OverlayServiceBuilder::build`].
///
/// The tasks are started by [`OverlayServiceBackgroundTasks::spawn`].
pub struct OverlayServiceBackgroundTasks {
    /// State shared with the [`OverlayService`] handle created by the same builder.
    inner: Arc<OverlayServiceInner>,
    /// Optional DHT service, handed over to the background tasks on spawn.
    dht: Option<DhtService>,
}
41
42impl OverlayServiceBackgroundTasks {
43 pub fn spawn(self, network: &Network) {
44 self.inner
45 .start_background_tasks(Network::downgrade(network), self.dht);
46 }
47}
48
/// Builder for [`OverlayService`], obtained from [`OverlayService::builder`].
pub struct OverlayServiceBuilder {
    /// Local peer id the service is built for.
    local_id: PeerId,
    /// Service configuration; `OverlayConfig::default()` is used when unset.
    config: Option<OverlayConfig>,
    /// Optional DHT service, forwarded to the background tasks.
    dht: Option<DhtService>,
}
54
55impl OverlayServiceBuilder {
56 pub fn with_config(mut self, config: OverlayConfig) -> Self {
57 self.config = Some(config);
58 self
59 }
60
61 pub fn with_dht_service(mut self, dht: DhtService) -> Self {
62 self.dht = Some(dht);
63 self
64 }
65
66 pub fn build(self) -> (OverlayServiceBackgroundTasks, OverlayService) {
67 let config = self.config.unwrap_or_default();
68
69 let inner = Arc::new(OverlayServiceInner {
70 local_id: self.local_id,
71 config,
72 private_overlays: Default::default(),
73 public_overlays: Default::default(),
74 public_overlays_changed: Arc::new(Notify::new()),
75 private_overlays_changed: Arc::new(Notify::new()),
76 public_entries_merger: Arc::new(PublicOverlayEntriesMerger),
77 });
78
79 let background_tasks = OverlayServiceBackgroundTasks {
80 inner: inner.clone(),
81 dht: self.dht,
82 };
83
84 (background_tasks, OverlayService(inner))
85 }
86}
87
/// Handle to the overlay service.
///
/// Cloning only clones the inner `Arc`, so all clones share the same set of
/// registered private and public overlays.
#[derive(Clone)]
pub struct OverlayService(Arc<OverlayServiceInner>);
90
91impl OverlayService {
92 pub fn builder(local_id: PeerId) -> OverlayServiceBuilder {
93 OverlayServiceBuilder {
94 local_id,
95 config: None,
96 dht: None,
97 }
98 }
99
100 pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
101 self.0.add_private_overlay(overlay)
102 }
103
104 pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
105 self.0.remove_private_overlay(overlay_id)
106 }
107
108 pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
109 self.0.add_public_overlay(overlay)
110 }
111
112 pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
113 self.0.remove_public_overlay(overlay_id)
114 }
115
116 pub fn public_overlays(&self) -> FastHashMap<OverlayId, PublicOverlay> {
117 self.0
118 .public_overlays
119 .iter()
120 .map(|item| (*item.key(), item.value().clone()))
121 .collect()
122 }
123
124 pub fn private_overlays(&self) -> FastHashMap<OverlayId, PrivateOverlay> {
125 self.0
126 .private_overlays
127 .iter()
128 .map(|item| (*item.key(), item.value().clone()))
129 .collect()
130 }
131}
132
impl Service<ServiceRequest> for OverlayService {
    type QueryResponse = Response;
    type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
    type OnMessageFuture = BoxFutureOrNoop<()>;

    /// Routes an incoming query.
    ///
    /// - Bare `ExchangeRandomPublicEntries` and `GetPublicEntry` requests are
    ///   answered directly by the service with an immediately ready response.
    /// - Any other query must start with an `rpc::Prefix` carrying the overlay
    ///   id. The prefix is stripped from `req.body` and the remaining request
    ///   is forwarded to the matching overlay. Private overlays are checked
    ///   before public ones.
    ///
    /// Malformed requests and unknown overlay ids are logged at debug level and
    /// yield `BoxFutureOrNoop::Noop` (no response).
    #[tracing::instrument(
        level = "debug",
        name = "on_overlay_query",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_query(&self, mut req: ServiceRequest) -> Self::OnQueryFuture {
        // The labeled block evaluates to a deserialization error; every
        // successful path leaves the function via `return`.
        let e = 'req: {
            let mut req_body = req.body.as_ref();
            // Not enough bytes for a 4-byte constructor id.
            if req_body.len() < 4 {
                break 'req TlError::UnexpectedEof;
            }

            // Peek the constructor id from a copy of the slice, so `req_body`
            // itself is not advanced before the full `read_from` below.
            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
                    Err(e) => break 'req e,
                },
                rpc::ExchangeRandomPublicEntries::TL_ID => {
                    // Unprefixed request handled by the service itself.
                    let req = match tl_proto::deserialize::<rpc::ExchangeRandomPublicEntries>(
                        &req.body,
                    ) {
                        Ok(req) => req,
                        Err(e) => break 'req e,
                    };
                    tracing::debug!("exchange_random_public_entries");

                    let res = self.0.handle_exchange_public_entries(&req);
                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
                        Response::from_tl(res),
                    )));
                }
                rpc::GetPublicEntry::TL_ID => {
                    // Unprefixed request handled by the service itself.
                    let req = match tl_proto::deserialize::<rpc::GetPublicEntry>(&req.body) {
                        Ok(req) => req,
                        Err(e) => break 'req e,
                    };
                    tracing::debug!("get_public_entry");

                    let res = self.0.handle_get_public_entry(&req);
                    return BoxFutureOrNoop::future(futures_util::future::ready(Some(
                        Response::from_tl(res),
                    )));
                }
                _ => break 'req TlError::UnknownConstructor,
            };

            // The wrapped request must contain at least its own constructor id.
            if req_body.len() < 4 {
                break 'req TlError::UnexpectedEof;
            }
            // Number of bytes taken by the prefix.
            let offset = req.body.len() - req_body.len();

            // NOTE(review): `overlay_id` may borrow from `req.body`. The body is
            // advanced only inside the branches that return right after, while
            // the fall-through path still reads `overlay_id` for logging, so
            // keep this statement order.
            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
                req.body.advance(offset);
                return private_overlay.handle_query(req);
            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
                req.body.advance(offset);
                return public_overlay.handle_query(req);
            }

            tracing::debug!(
                overlay_id = %OverlayId::wrap(overlay_id),
                "unknown overlay id"
            );
            return BoxFutureOrNoop::Noop;
        };

        tracing::debug!("failed to deserialize query: {e:?}");
        BoxFutureOrNoop::Noop
    }

    /// Routes an incoming message.
    ///
    /// Messages must start with an `rpc::Prefix` carrying the overlay id. The
    /// prefix is stripped from `req.body` and the rest is forwarded to the
    /// matching private (checked first) or public overlay. Malformed messages
    /// and unknown overlay ids are logged at debug level and dropped.
    #[tracing::instrument(
        level = "debug",
        name = "on_overlay_message",
        skip_all,
        fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
    )]
    fn on_message(&self, mut req: ServiceRequest) -> Self::OnMessageFuture {
        // Evaluates to a deserialization error; successful paths `return`.
        let e = 'req: {
            let mut req_body = req.body.as_ref();
            // Not enough bytes for a 4-byte constructor id.
            if req_body.len() < 4 {
                break 'req TlError::UnexpectedEof;
            }

            // Peek the constructor id without advancing `req_body`.
            let overlay_id = match std::convert::identity(req_body).get_u32_le() {
                rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
                    Ok(rpc::Prefix { overlay_id }) => overlay_id,
                    Err(e) => break 'req e,
                },
                _ => break 'req TlError::UnknownConstructor,
            };

            // The wrapped message must contain at least its own constructor id.
            if req_body.len() < 4 {
                break 'req TlError::UnexpectedEof;
            }
            // Number of bytes taken by the prefix.
            let offset = req.body.len() - req_body.len();

            // NOTE(review): same ordering constraint as in `on_query` — advance
            // the body only on the paths that no longer use `overlay_id`.
            if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
                req.body.advance(offset);
                return private_overlay.handle_message(req);
            } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
                req.body.advance(offset);
                return public_overlay.handle_message(req);
            }

            tracing::debug!(
                overlay_id = %OverlayId::wrap(overlay_id),
                "unknown overlay id"
            );
            return BoxFutureOrNoop::Noop;
        };

        tracing::debug!("failed to deserialize message: {e:?}");
        BoxFutureOrNoop::Noop
    }
}
258
259impl Routable for OverlayService {
260 fn query_ids(&self) -> impl IntoIterator<Item = u32> {
261 [
262 rpc::ExchangeRandomPublicEntries::TL_ID,
263 rpc::GetPublicEntry::TL_ID,
264 rpc::Prefix::TL_ID,
265 ]
266 }
267
268 fn message_ids(&self) -> impl IntoIterator<Item = u32> {
269 [rpc::Prefix::TL_ID]
270 }
271}
272
/// State shared between [`OverlayService`] handles and the background tasks.
struct OverlayServiceInner {
    /// Local peer id; passed to `add_untrusted_entries` when merging entries
    /// received from other peers.
    local_id: PeerId,
    /// Service configuration (e.g. `exchange_public_entries_batch`).
    config: OverlayConfig,
    /// Registered public overlays, keyed by overlay id.
    public_overlays: FastDashMap<OverlayId, PublicOverlay>,
    /// Registered private overlays, keyed by overlay id. The `add_*_overlay`
    /// methods refuse an id already present in the other map.
    private_overlays: FastDashMap<OverlayId, PrivateOverlay>,
    /// `notify_waiters` is called on it whenever `add_public_overlay` or
    /// `remove_public_overlay` actually changes `public_overlays`.
    public_overlays_changed: Arc<Notify>,
    /// `notify_waiters` is called on it whenever `add_private_overlay` or
    /// `remove_private_overlay` actually changes `private_overlays`.
    private_overlays_changed: Arc<Notify>,
    /// NOTE(review): not used in the code visible here; presumably consumed by
    /// the background tasks when merging public overlay entries.
    public_entries_merger: Arc<PublicOverlayEntriesMerger>,
}
282
283impl OverlayServiceInner {
284 fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
285 use dashmap::mapref::entry::Entry;
286
287 if self.public_overlays.contains_key(overlay.overlay_id()) {
288 return false;
289 }
290 match self.private_overlays.entry(*overlay.overlay_id()) {
291 Entry::Vacant(entry) => {
292 entry.insert(overlay.clone());
293 self.private_overlays_changed.notify_waiters();
294 true
295 }
296 Entry::Occupied(_) => false,
297 }
298 }
299
300 fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
301 let removed = self.private_overlays.remove(overlay_id).is_some();
302 if removed {
303 self.private_overlays_changed.notify_waiters();
304 }
305 removed
306 }
307
308 fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
309 use dashmap::mapref::entry::Entry;
310
311 if self.private_overlays.contains_key(overlay.overlay_id()) {
312 return false;
313 }
314 match self.public_overlays.entry(*overlay.overlay_id()) {
315 Entry::Vacant(entry) => {
316 entry.insert(overlay.clone());
317 self.public_overlays_changed.notify_waiters();
318 true
319 }
320 Entry::Occupied(_) => false,
321 }
322 }
323
324 fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
325 let removed = self.public_overlays.remove(overlay_id).is_some();
326 if removed {
327 self.public_overlays_changed.notify_waiters();
328 }
329 removed
330 }
331
332 fn handle_exchange_public_entries(
333 &self,
334 req: &rpc::ExchangeRandomPublicEntries,
335 ) -> PublicEntriesResponse {
336 debug_assert!(req.entries.len() <= 20);
338
339 let overlay = match self.public_overlays.get(&req.overlay_id) {
341 Some(overlay) => overlay,
342 None => return PublicEntriesResponse::OverlayNotFound,
343 };
344
345 overlay.add_untrusted_entries(&self.local_id, &req.entries, now_sec());
347
348 let requested_ids = req
350 .entries
351 .iter()
352 .map(|id| id.peer_id)
353 .collect::<FastHashSet<_>>();
354
355 let entries = {
356 let entries = overlay.read_entries();
357
358 let n = self.config.exchange_public_entries_batch;
360 entries
361 .choose_multiple(&mut rand::rng(), n + requested_ids.len())
362 .filter_map(|item| {
363 let is_new = !requested_ids.contains(&item.entry.peer_id);
364 is_new.then(|| item.entry.clone())
365 })
366 .take(n)
367 .collect::<Vec<_>>()
368 };
369
370 PublicEntriesResponse::PublicEntries(entries)
371 }
372
373 fn handle_get_public_entry(&self, req: &rpc::GetPublicEntry) -> PublicEntryResponse {
374 let overlay = match self.public_overlays.get(&req.overlay_id) {
376 Some(overlay) => overlay,
377 None => return PublicEntryResponse::OverlayNotFound,
378 };
379
380 let Some(entry) = overlay.own_signed_entry() else {
381 return PublicEntryResponse::OverlayNotFound;
385 };
386
387 PublicEntryResponse::Found(entry)
388 }
389}