1use std::sync::Arc;
2
3use bytes::Buf;
4use tl_proto::{TlError, TlRead};
5use tokio::sync::Notify;
6use tycho_util::futures::BoxFutureOrNoop;
7use tycho_util::time::now_sec;
8use tycho_util::{FastDashMap, FastHashSet};
9
10pub use self::config::OverlayConfig;
11use self::entries_merger::PublicOverlayEntriesMerger;
12pub use self::overlay_id::OverlayId;
13pub use self::private_overlay::{
14 ChooseMultiplePrivateOverlayEntries, PrivateOverlay, PrivateOverlayBuilder,
15 PrivateOverlayEntries, PrivateOverlayEntriesEvent, PrivateOverlayEntriesReadGuard,
16 PrivateOverlayEntriesWriteGuard, PrivateOverlayEntryData,
17};
18pub use self::public_overlay::{
19 ChooseMultiplePublicOverlayEntries, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries,
20 PublicOverlayEntriesReadGuard, PublicOverlayEntryData, UnknownPeersQueue,
21};
22use crate::dht::DhtService;
23use crate::network::Network;
24use crate::proto::overlay::{PublicEntriesResponse, PublicEntry, PublicEntryResponse, rpc};
25use crate::types::{PeerId, Response, Service, ServiceRequest};
26use crate::util::Routable;
27
28mod background_tasks;
29mod config;
30mod entries_merger;
31mod metrics;
32mod overlay_id;
33mod private_overlay;
34mod public_overlay;
35mod tasks_stream;
36
37pub struct OverlayServiceBackgroundTasks {
38 inner: Arc<OverlayServiceInner>,
39 dht: Option<DhtService>,
40}
41
42impl OverlayServiceBackgroundTasks {
43 pub fn spawn(self, network: &Network) {
44 self.inner
45 .start_background_tasks(Network::downgrade(network), self.dht);
46 }
47}
48
49pub struct OverlayServiceBuilder {
50 local_id: PeerId,
51 config: Option<OverlayConfig>,
52 dht: Option<DhtService>,
53}
54
55impl OverlayServiceBuilder {
56 pub fn with_config(mut self, config: OverlayConfig) -> Self {
57 self.config = Some(config);
58 self
59 }
60
61 pub fn with_dht_service(mut self, dht: DhtService) -> Self {
62 self.dht = Some(dht);
63 self
64 }
65
66 pub fn build(self) -> (OverlayServiceBackgroundTasks, OverlayService) {
67 let config = self.config.unwrap_or_default();
68
69 let inner = Arc::new(OverlayServiceInner {
70 local_id: self.local_id,
71 config,
72 private_overlays: Default::default(),
73 public_overlays: Default::default(),
74 public_overlays_changed: Arc::new(Notify::new()),
75 private_overlays_changed: Arc::new(Notify::new()),
76 public_entries_merger: Arc::new(PublicOverlayEntriesMerger),
77 });
78
79 let background_tasks = OverlayServiceBackgroundTasks {
80 inner: inner.clone(),
81 dht: self.dht,
82 };
83
84 (background_tasks, OverlayService(inner))
85 }
86}
87
88#[derive(Clone)]
89pub struct OverlayService(Arc<OverlayServiceInner>);
90
91impl OverlayService {
92 pub fn builder(local_id: PeerId) -> OverlayServiceBuilder {
93 OverlayServiceBuilder {
94 local_id,
95 config: None,
96 dht: None,
97 }
98 }
99
100 pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
101 self.0.add_private_overlay(overlay)
102 }
103
104 pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
105 self.0.remove_private_overlay(overlay_id)
106 }
107
108 pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
109 self.0.add_public_overlay(overlay)
110 }
111
112 pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
113 self.0.remove_public_overlay(overlay_id)
114 }
115}
116
117impl Service<ServiceRequest> for OverlayService {
118 type QueryResponse = Response;
119 type OnQueryFuture = BoxFutureOrNoop<Option<Self::QueryResponse>>;
120 type OnMessageFuture = BoxFutureOrNoop<()>;
121
122 #[tracing::instrument(
123 level = "debug",
124 name = "on_overlay_query",
125 skip_all,
126 fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
127 )]
128 fn on_query(&self, mut req: ServiceRequest) -> Self::OnQueryFuture {
129 let e = 'req: {
130 let mut req_body = req.body.as_ref();
131 if req_body.len() < 4 {
132 break 'req TlError::UnexpectedEof;
133 }
134
135 let overlay_id = match std::convert::identity(req_body).get_u32_le() {
136 rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
137 Ok(rpc::Prefix { overlay_id }) => overlay_id,
138 Err(e) => break 'req e,
139 },
140 rpc::ExchangeRandomPublicEntries::TL_ID => {
141 let req = match tl_proto::deserialize::<rpc::ExchangeRandomPublicEntries>(
142 &req.body,
143 ) {
144 Ok(req) => req,
145 Err(e) => break 'req e,
146 };
147 tracing::debug!("exchange_random_public_entries");
148
149 let res = self.0.handle_exchange_public_entries(&req);
150 return BoxFutureOrNoop::future(futures_util::future::ready(Some(
151 Response::from_tl(res),
152 )));
153 }
154 rpc::GetPublicEntry::TL_ID => {
155 let req = match tl_proto::deserialize::<rpc::GetPublicEntry>(&req.body) {
156 Ok(req) => req,
157 Err(e) => break 'req e,
158 };
159 tracing::debug!("get_public_entry");
160
161 let res = self.0.handle_get_public_entry(&req);
162 return BoxFutureOrNoop::future(futures_util::future::ready(Some(
163 Response::from_tl(res),
164 )));
165 }
166 _ => break 'req TlError::UnknownConstructor,
167 };
168
169 if req_body.len() < 4 {
170 break 'req TlError::UnexpectedEof;
172 }
173 let offset = req.body.len() - req_body.len();
174
175 if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
176 req.body.advance(offset);
177 return private_overlay.handle_query(req);
178 } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
179 req.body.advance(offset);
180 return public_overlay.handle_query(req);
181 }
182
183 tracing::debug!(
184 overlay_id = %OverlayId::wrap(overlay_id),
185 "unknown overlay id"
186 );
187 return BoxFutureOrNoop::Noop;
188 };
189
190 tracing::debug!("failed to deserialize query: {e:?}");
191 BoxFutureOrNoop::Noop
192 }
193
194 #[tracing::instrument(
195 level = "debug",
196 name = "on_overlay_message",
197 skip_all,
198 fields(peer_id = %req.metadata.peer_id, addr = %req.metadata.remote_address)
199 )]
200 fn on_message(&self, mut req: ServiceRequest) -> Self::OnMessageFuture {
201 let e = 'req: {
204 let mut req_body = req.body.as_ref();
205 if req_body.len() < 4 {
206 break 'req TlError::UnexpectedEof;
207 }
208
209 let overlay_id = match std::convert::identity(req_body).get_u32_le() {
210 rpc::Prefix::TL_ID => match rpc::Prefix::read_from(&mut req_body) {
211 Ok(rpc::Prefix { overlay_id }) => overlay_id,
212 Err(e) => break 'req e,
213 },
214 _ => break 'req TlError::UnknownConstructor,
215 };
216
217 if req_body.len() < 4 {
218 break 'req TlError::UnexpectedEof;
220 }
221 let offset = req.body.len() - req_body.len();
222
223 if let Some(private_overlay) = self.0.private_overlays.get(overlay_id) {
224 req.body.advance(offset);
225 return private_overlay.handle_message(req);
226 } else if let Some(public_overlay) = self.0.public_overlays.get(overlay_id) {
227 req.body.advance(offset);
228 return public_overlay.handle_message(req);
229 }
230
231 tracing::debug!(
232 overlay_id = %OverlayId::wrap(overlay_id),
233 "unknown overlay id"
234 );
235 return BoxFutureOrNoop::Noop;
236 };
237
238 tracing::debug!("failed to deserialize message: {e:?}");
239 BoxFutureOrNoop::Noop
240 }
241}
242
243impl Routable for OverlayService {
244 fn query_ids(&self) -> impl IntoIterator<Item = u32> {
245 [
246 rpc::ExchangeRandomPublicEntries::TL_ID,
247 rpc::GetPublicEntry::TL_ID,
248 rpc::Prefix::TL_ID,
249 ]
250 }
251
252 fn message_ids(&self) -> impl IntoIterator<Item = u32> {
253 [rpc::Prefix::TL_ID]
254 }
255}
256
257struct OverlayServiceInner {
258 local_id: PeerId,
259 config: OverlayConfig,
260 public_overlays: FastDashMap<OverlayId, PublicOverlay>,
261 private_overlays: FastDashMap<OverlayId, PrivateOverlay>,
262 public_overlays_changed: Arc<Notify>,
263 private_overlays_changed: Arc<Notify>,
264 public_entries_merger: Arc<PublicOverlayEntriesMerger>,
265}
266
267impl OverlayServiceInner {
268 fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool {
269 use dashmap::mapref::entry::Entry;
270
271 if self.public_overlays.contains_key(overlay.overlay_id()) {
272 return false;
273 }
274 match self.private_overlays.entry(*overlay.overlay_id()) {
275 Entry::Vacant(entry) => {
276 entry.insert(overlay.clone());
277 self.private_overlays_changed.notify_waiters();
278 true
279 }
280 Entry::Occupied(_) => false,
281 }
282 }
283
284 fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool {
285 let removed = self.private_overlays.remove(overlay_id).is_some();
286 if removed {
287 self.private_overlays_changed.notify_waiters();
288 }
289 removed
290 }
291
292 fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool {
293 use dashmap::mapref::entry::Entry;
294
295 if self.private_overlays.contains_key(overlay.overlay_id()) {
296 return false;
297 }
298 match self.public_overlays.entry(*overlay.overlay_id()) {
299 Entry::Vacant(entry) => {
300 entry.insert(overlay.clone());
301 self.public_overlays_changed.notify_waiters();
302 true
303 }
304 Entry::Occupied(_) => false,
305 }
306 }
307
308 fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool {
309 let removed = self.public_overlays.remove(overlay_id).is_some();
310 if removed {
311 self.public_overlays_changed.notify_waiters();
312 }
313 removed
314 }
315
316 fn handle_exchange_public_entries(
317 &self,
318 req: &rpc::ExchangeRandomPublicEntries,
319 ) -> PublicEntriesResponse {
320 debug_assert!(req.entries.len() <= 20);
322
323 let overlay = match self.public_overlays.get(&req.overlay_id) {
325 Some(overlay) => overlay,
326 None => return PublicEntriesResponse::OverlayNotFound,
327 };
328
329 overlay.add_untrusted_entries(&self.local_id, &req.entries, now_sec());
331
332 let requested_ids = req
334 .entries
335 .iter()
336 .map(|id| id.peer_id)
337 .collect::<FastHashSet<_>>();
338
339 let entries = {
340 let entries = overlay.read_entries();
341
342 let n = self.config.exchange_public_entries_batch;
344 entries
345 .choose_multiple(&mut rand::rng(), n + requested_ids.len())
346 .filter_map(|item| {
347 let is_new = !requested_ids.contains(&item.entry.peer_id);
348 is_new.then(|| item.entry.clone())
349 })
350 .take(n)
351 .collect::<Vec<_>>()
352 };
353
354 PublicEntriesResponse::PublicEntries(entries)
355 }
356
357 fn handle_get_public_entry(&self, req: &rpc::GetPublicEntry) -> PublicEntryResponse {
358 let overlay = match self.public_overlays.get(&req.overlay_id) {
360 Some(overlay) => overlay,
361 None => return PublicEntryResponse::OverlayNotFound,
362 };
363
364 let Some(entry) = overlay.own_signed_entry() else {
365 return PublicEntryResponse::OverlayNotFound;
369 };
370
371 PublicEntryResponse::Found(entry)
372 }
373}