1use std::sync::Arc;
16
17use bytes::Bytes;
18use chrono::Utc;
19use futures::future::BoxFuture;
20use scion_proto::{
21 address::{ScionAddr, SocketAddr},
22 packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
23 path::Path,
24 scmp::ScmpMessage,
25};
26
27use super::{NetworkError, UnderlaySocket};
28use crate::{
29 path::{
30 PathStrategy,
31 manager::{CachingPathManager, PathManager, PathWaitError},
32 policy::PathPolicy,
33 ranking::PathRanking,
34 },
35 scionstack::{ScionSocketReceiveError, ScionSocketSendError},
36};
37
38pub struct PathUnawareUdpScionSocket {
40 inner: Box<dyn UnderlaySocket + Sync + Send>,
41}
42
43impl std::fmt::Debug for PathUnawareUdpScionSocket {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("PathUnawareUdpScionSocket")
46 .field("local_addr", &self.inner.local_addr())
47 .finish()
48 }
49}
50
51impl PathUnawareUdpScionSocket {
52 pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
53 Self { inner: socket }
54 }
55
56 pub fn send_to_via<'a>(
58 &'a self,
59 payload: &[u8],
60 destination: SocketAddr,
61 path: &Path<&[u8]>,
62 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
63 let packet = match ScionPacketUdp::new(
64 ByEndpoint {
65 source: self.inner.local_addr(),
66 destination,
67 },
68 path.data_plane_path.to_bytes_path(),
69 Bytes::copy_from_slice(payload),
70 ) {
71 Ok(packet) => packet,
72 Err(e) => {
73 return Box::pin(async move {
74 Err(ScionSocketSendError::InvalidPacket(
75 format!("error encoding packet: {e}").into(),
76 ))
77 });
78 }
79 }
80 .into();
81 self.inner.send(packet)
82 }
83
84 #[allow(clippy::type_complexity)]
86 pub fn recv_from_with_path<'a>(
87 &'a self,
88 buffer: &'a mut [u8],
89 path_buffer: &'a mut [u8],
90 ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
91 {
92 Box::pin(async move {
93 loop {
94 let packet = self.inner.recv().await?;
95 let packet: ScionPacketUdp = match packet.try_into() {
96 Ok(packet) => packet,
97 Err(e) => {
98 tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
99 continue;
100 }
101 };
102 let src_addr = match packet.headers.address.source() {
103 Some(source) => SocketAddr::new(source, packet.src_port()),
104 None => {
105 tracing::debug!("Received packet without source address header, skipping");
106 continue;
107 }
108 };
109 tracing::trace!(
110 src = %src_addr,
111 length = packet.datagram.payload.len(),
112 "received packet",
113 );
114
115 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
116 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
117
118 if path_buffer.len() < packet.headers.path.raw().len() {
119 return Err(ScionSocketReceiveError::PathBufTooSmall);
120 }
121
122 let dataplane_path = packet
123 .headers
124 .path
125 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
126
127 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
131
132 return Ok((packet.datagram.payload.len(), src_addr, path));
133 }
134 })
135 }
136
137 pub fn recv_from<'a>(
139 &'a self,
140 buffer: &'a mut [u8],
141 ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
142 Box::pin(async move {
143 loop {
144 let packet = self.inner.recv().await?;
145 let packet: ScionPacketUdp = match packet.try_into() {
146 Ok(packet) => packet,
147 Err(e) => {
148 tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
149 continue;
150 }
151 };
152 let src_addr = match packet.headers.address.source() {
153 Some(source) => SocketAddr::new(source, packet.src_port()),
154 None => {
155 tracing::debug!("Received packet without source address header, dropping");
156 continue;
157 }
158 };
159
160 tracing::trace!(
161 src = %src_addr,
162 length = packet.datagram.payload.len(),
163 buffer_size = buffer.len(),
164 "received packet",
165 );
166
167 let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
168 buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
169
170 return Ok((packet.datagram.payload.len(), src_addr));
171 }
172 })
173 }
174
175 fn local_addr(&self) -> SocketAddr {
177 self.inner.local_addr()
178 }
179}
180
181pub struct ScmpScionSocket {
183 inner: Box<dyn UnderlaySocket + Sync + Send>,
184}
185
186impl ScmpScionSocket {
187 pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
188 Self { inner: socket }
189 }
190}
191
192impl ScmpScionSocket {
193 pub fn send_to_via<'a>(
195 &'a self,
196 message: ScmpMessage,
197 destination: ScionAddr,
198 path: &Path<&[u8]>,
199 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
200 let packet = match ScionPacketScmp::new(
201 ByEndpoint {
202 source: self.inner.local_addr().scion_address(),
203 destination,
204 },
205 path.data_plane_path.to_bytes_path(),
206 message,
207 ) {
208 Ok(packet) => packet,
209 Err(e) => {
210 return Box::pin(async move {
211 Err(ScionSocketSendError::InvalidPacket(
212 format!("error encoding packet: {e}").into(),
213 ))
214 });
215 }
216 };
217 let packet = packet.into();
218 Box::pin(async move { self.inner.send(packet).await })
219 }
220
221 #[allow(clippy::type_complexity)]
223 pub fn recv_from_with_path<'a>(
224 &'a self,
225 path_buffer: &'a mut [u8],
226 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
227 {
228 Box::pin(async move {
229 loop {
230 let packet = self.inner.recv().await?;
231 let packet: ScionPacketScmp = match packet.try_into() {
232 Ok(packet) => packet,
233 Err(e) => {
234 tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
235 continue;
236 }
237 };
238 let src_addr = match packet.headers.address.source() {
239 Some(source) => source,
240 None => {
241 tracing::debug!("Received packet without source address header, dropping");
242 continue;
243 }
244 };
245
246 if path_buffer.len() < packet.headers.path.raw().len() {
247 return Err(ScionSocketReceiveError::PathBufTooSmall);
248 }
249 let dataplane_path = packet
250 .headers
251 .path
252 .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
253 let path = Path::new(dataplane_path, packet.headers.address.ia, None);
254
255 return Ok((packet.message, src_addr, path));
256 }
257 })
258 }
259
260 pub fn recv_from<'a>(
262 &'a self,
263 ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
264 Box::pin(async move {
265 loop {
266 let packet = self.inner.recv().await?;
267 let packet: ScionPacketScmp = match packet.try_into() {
268 Ok(packet) => packet,
269 Err(e) => {
270 tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
271 continue;
272 }
273 };
274 let src_addr = match packet.headers.address.source() {
275 Some(source) => source,
276 None => {
277 tracing::debug!("Received packet without source address header, skipping");
278 continue;
279 }
280 };
281 return Ok((packet.message, src_addr));
282 }
283 })
284 }
285
286 pub fn local_addr(&self) -> SocketAddr {
288 self.inner.local_addr()
289 }
290}
291
292pub struct RawScionSocket {
294 inner: Box<dyn UnderlaySocket>,
295}
296
297impl RawScionSocket {
298 pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
299 Self { inner: socket }
300 }
301}
302
303impl RawScionSocket {
304 pub fn send<'a>(
306 &'a self,
307 packet: ScionPacketRaw,
308 ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
309 self.inner.send(packet)
310 }
311
312 pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
314 self.inner.recv()
315 }
316
317 pub fn local_addr(&self) -> SocketAddr {
319 self.inner.local_addr()
320 }
321}
322
323#[derive(Default)]
325pub struct SocketConfig {
326 pub(crate) path_strategy: PathStrategy,
327}
328impl SocketConfig {
329 pub fn new() -> Self {
331 Self::default()
332 }
333
334 pub fn with_path_policy(mut self, policy: impl PathPolicy) -> Self {
342 self.path_strategy.add_policy(policy);
343 self
344 }
345
346 pub fn with_path_ranking(mut self, ranking: impl PathRanking) -> Self {
356 self.path_strategy.add_ranking(ranking);
357 self
358 }
359}
360
361pub struct UdpScionSocket<P: PathManager = CachingPathManager> {
363 socket: PathUnawareUdpScionSocket,
364 pather: Arc<P>,
365 remote_addr: Option<SocketAddr>,
366}
367
368impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 f.debug_struct("UdpScionSocket")
371 .field("local_addr", &self.socket.local_addr())
372 .field("remote_addr", &self.remote_addr)
373 .finish()
374 }
375}
376
377impl<P: PathManager> UdpScionSocket<P> {
378 pub fn new(
380 socket: PathUnawareUdpScionSocket,
381 pather: Arc<P>,
382 remote_addr: Option<SocketAddr>,
383 ) -> Self {
384 Self {
385 socket,
386 pather,
387 remote_addr,
388 }
389 }
390
391 pub fn connect(self, remote_addr: SocketAddr) -> Self {
393 Self {
394 remote_addr: Some(remote_addr),
395 ..self
396 }
397 }
398
399 pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
401 if let Some(remote_addr) = self.remote_addr {
402 self.send_to(payload, remote_addr).await
403 } else {
404 Err(ScionSocketSendError::NotConnected)
405 }
406 }
407
408 pub async fn send_to(
410 &self,
411 payload: &[u8],
412 destination: SocketAddr,
413 ) -> Result<(), ScionSocketSendError> {
414 let path = &self
415 .pather
416 .path_wait(
417 self.socket.local_addr().isd_asn(),
418 destination.isd_asn(),
419 Utc::now(),
420 )
421 .await
422 .map_err(|e| {
423 match e {
424 PathWaitError::FetchFailed(e) => {
425 ScionSocketSendError::PathLookupError(e.into())
426 }
427 PathWaitError::NoPathFound => {
428 ScionSocketSendError::NetworkUnreachable(
429 NetworkError::DestinationUnreachable("No path found".to_string()),
430 )
431 }
432 }
433 })?;
434 self.socket
435 .send_to_via(payload, destination, &path.to_slice_path())
436 .await
437 }
438
439 pub async fn send_to_via(
441 &self,
442 payload: &[u8],
443 destination: SocketAddr,
444 path: &Path<&[u8]>,
445 ) -> Result<(), ScionSocketSendError> {
446 self.socket.send_to_via(payload, destination, path).await
447 }
448
449 pub async fn recv_from_with_path<'a>(
451 &'a self,
452 buffer: &'a mut [u8],
453 path_buffer: &'a mut [u8],
454 ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
455 let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
456 self.socket.recv_from_with_path(buffer, path_buffer).await?;
457
458 match path.to_reversed() {
459 Ok(reversed_path) => {
460 self.pather.register_path(
462 self.socket.local_addr().isd_asn(),
463 sender_addr.isd_asn(),
464 Utc::now(),
465 reversed_path,
466 );
467 }
468 Err(e) => {
469 tracing::trace!(error = ?e, "Failed to reverse path for registration")
470 }
471 }
472
473 tracing::trace!(
474 src = %self.socket.local_addr(),
475 dst = %sender_addr,
476 "Registered reverse path",
477 );
478
479 Ok((len, sender_addr, path))
480 }
481
482 pub async fn recv_from(
484 &self,
485 buffer: &mut [u8],
486 ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
487 let mut path_buffer = [0u8; 1024]; let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
490 Ok((len, sender_addr))
491 }
492
493 pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
497 if self.remote_addr.is_none() {
498 return Err(ScionSocketReceiveError::NotConnected);
499 }
500 loop {
501 let (len, sender_addr) = self.recv_from(buffer).await?;
502 match self.remote_addr {
503 Some(remote_addr) => {
504 if sender_addr == remote_addr {
505 return Ok(len);
506 }
507 }
508 None => return Err(ScionSocketReceiveError::NotConnected),
509 }
510 }
511 }
512
513 pub fn local_addr(&self) -> SocketAddr {
515 self.socket.local_addr()
516 }
517}