scion_stack/scionstack/
socket.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use bytes::Bytes;
18use chrono::Utc;
19use futures::future::BoxFuture;
20use scion_proto::{
21    address::{ScionAddr, SocketAddr},
22    packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
23    path::Path,
24    scmp::ScmpMessage,
25};
26
27use super::{NetworkError, UnderlaySocket};
28use crate::{
29    path::{
30        PathStrategy,
31        manager::{CachingPathManager, PathManager, PathWaitError},
32        policy::PathPolicy,
33        ranking::PathRanking,
34    },
35    scionstack::{ScionSocketReceiveError, ScionSocketSendError},
36};
37
38/// A path unaware UDP SCION socket.
39pub struct PathUnawareUdpScionSocket {
40    inner: Box<dyn UnderlaySocket + Sync + Send>,
41}
42
43impl std::fmt::Debug for PathUnawareUdpScionSocket {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("PathUnawareUdpScionSocket")
46            .field("local_addr", &self.inner.local_addr())
47            .finish()
48    }
49}
50
51impl PathUnawareUdpScionSocket {
52    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
53        Self { inner: socket }
54    }
55
56    /// Send a SCION UDP datagram via the given path.
57    pub fn send_to_via<'a>(
58        &'a self,
59        payload: &[u8],
60        destination: SocketAddr,
61        path: &Path<&[u8]>,
62    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
63        let packet = match ScionPacketUdp::new(
64            ByEndpoint {
65                source: self.inner.local_addr(),
66                destination,
67            },
68            path.data_plane_path.to_bytes_path(),
69            Bytes::copy_from_slice(payload),
70        ) {
71            Ok(packet) => packet,
72            Err(e) => {
73                return Box::pin(async move {
74                    Err(ScionSocketSendError::InvalidPacket(
75                        format!("error encoding packet: {e}").into(),
76                    ))
77                });
78            }
79        }
80        .into();
81        self.inner.send(packet)
82    }
83
84    /// Receive a SCION packet with the sender and path.
85    #[allow(clippy::type_complexity)]
86    pub fn recv_from_with_path<'a>(
87        &'a self,
88        buffer: &'a mut [u8],
89        path_buffer: &'a mut [u8],
90    ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
91    {
92        Box::pin(async move {
93            loop {
94                let packet = self.inner.recv().await?;
95                let packet: ScionPacketUdp = match packet.try_into() {
96                    Ok(packet) => packet,
97                    Err(e) => {
98                        tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
99                        continue;
100                    }
101                };
102                let src_addr = match packet.headers.address.source() {
103                    Some(source) => SocketAddr::new(source, packet.src_port()),
104                    None => {
105                        tracing::debug!("Received packet without source address header, skipping");
106                        continue;
107                    }
108                };
109                tracing::trace!(
110                    src = %src_addr,
111                    length = packet.datagram.payload.len(),
112                    "received packet",
113                );
114
115                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
116                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
117
118                if path_buffer.len() < packet.headers.path.raw().len() {
119                    return Err(ScionSocketReceiveError::PathBufTooSmall);
120                }
121
122                let dataplane_path = packet
123                    .headers
124                    .path
125                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
126
127                // Note, that we do not have the next hop address of the path.
128                // A socket that uses more than one tunnel will need to distinguish between
129                // packets received on different tunnels.
130                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
131
132                return Ok((packet.datagram.payload.len(), src_addr, path));
133            }
134        })
135    }
136
137    /// Receive a SCION packet with the sender.
138    pub fn recv_from<'a>(
139        &'a self,
140        buffer: &'a mut [u8],
141    ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
142        Box::pin(async move {
143            loop {
144                let packet = self.inner.recv().await?;
145                let packet: ScionPacketUdp = match packet.try_into() {
146                    Ok(packet) => packet,
147                    Err(e) => {
148                        tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
149                        continue;
150                    }
151                };
152                let src_addr = match packet.headers.address.source() {
153                    Some(source) => SocketAddr::new(source, packet.src_port()),
154                    None => {
155                        tracing::debug!("Received packet without source address header, dropping");
156                        continue;
157                    }
158                };
159
160                tracing::trace!(
161                    src = %src_addr,
162                    length = packet.datagram.payload.len(),
163                    buffer_size = buffer.len(),
164                    "received packet",
165                );
166
167                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
168                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
169
170                return Ok((packet.datagram.payload.len(), src_addr));
171            }
172        })
173    }
174
175    /// The local address the socket is bound to.
176    fn local_addr(&self) -> SocketAddr {
177        self.inner.local_addr()
178    }
179}
180
181/// A SCMP SCION socket.
182pub struct ScmpScionSocket {
183    inner: Box<dyn UnderlaySocket + Sync + Send>,
184}
185
186impl ScmpScionSocket {
187    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
188        Self { inner: socket }
189    }
190}
191
192impl ScmpScionSocket {
193    /// Send a SCMP message to the destination via the given path.
194    pub fn send_to_via<'a>(
195        &'a self,
196        message: ScmpMessage,
197        destination: ScionAddr,
198        path: &Path<&[u8]>,
199    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
200        let packet = match ScionPacketScmp::new(
201            ByEndpoint {
202                source: self.inner.local_addr().scion_address(),
203                destination,
204            },
205            path.data_plane_path.to_bytes_path(),
206            message,
207        ) {
208            Ok(packet) => packet,
209            Err(e) => {
210                return Box::pin(async move {
211                    Err(ScionSocketSendError::InvalidPacket(
212                        format!("error encoding packet: {e}").into(),
213                    ))
214                });
215            }
216        };
217        let packet = packet.into();
218        Box::pin(async move { self.inner.send(packet).await })
219    }
220
221    /// Receive a SCMP message with the sender and path.
222    #[allow(clippy::type_complexity)]
223    pub fn recv_from_with_path<'a>(
224        &'a self,
225        path_buffer: &'a mut [u8],
226    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
227    {
228        Box::pin(async move {
229            loop {
230                let packet = self.inner.recv().await?;
231                let packet: ScionPacketScmp = match packet.try_into() {
232                    Ok(packet) => packet,
233                    Err(e) => {
234                        tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
235                        continue;
236                    }
237                };
238                let src_addr = match packet.headers.address.source() {
239                    Some(source) => source,
240                    None => {
241                        tracing::debug!("Received packet without source address header, dropping");
242                        continue;
243                    }
244                };
245
246                if path_buffer.len() < packet.headers.path.raw().len() {
247                    return Err(ScionSocketReceiveError::PathBufTooSmall);
248                }
249                let dataplane_path = packet
250                    .headers
251                    .path
252                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
253                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
254
255                return Ok((packet.message, src_addr, path));
256            }
257        })
258    }
259
260    /// Receive a SCMP message with the sender.
261    pub fn recv_from<'a>(
262        &'a self,
263    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
264        Box::pin(async move {
265            loop {
266                let packet = self.inner.recv().await?;
267                let packet: ScionPacketScmp = match packet.try_into() {
268                    Ok(packet) => packet,
269                    Err(e) => {
270                        tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
271                        continue;
272                    }
273                };
274                let src_addr = match packet.headers.address.source() {
275                    Some(source) => source,
276                    None => {
277                        tracing::debug!("Received packet without source address header, skipping");
278                        continue;
279                    }
280                };
281                return Ok((packet.message, src_addr));
282            }
283        })
284    }
285
286    /// Return the local socket address.
287    pub fn local_addr(&self) -> SocketAddr {
288        self.inner.local_addr()
289    }
290}
291
292/// A raw SCION socket.
293pub struct RawScionSocket {
294    inner: Box<dyn UnderlaySocket>,
295}
296
297impl RawScionSocket {
298    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
299        Self { inner: socket }
300    }
301}
302
303impl RawScionSocket {
304    /// Send a raw SCION packet.
305    pub fn send<'a>(
306        &'a self,
307        packet: ScionPacketRaw,
308    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
309        self.inner.send(packet)
310    }
311
312    /// Receive a raw SCION packet.
313    pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
314        self.inner.recv()
315    }
316
317    /// Return the local socket address.
318    pub fn local_addr(&self) -> SocketAddr {
319        self.inner.local_addr()
320    }
321}
322
323/// Configuration for a path aware socket.
324#[derive(Default)]
325pub struct SocketConfig {
326    pub(crate) path_strategy: PathStrategy,
327}
328impl SocketConfig {
329    /// Creates a new default socket configuration.
330    pub fn new() -> Self {
331        Self::default()
332    }
333
334    /// Adds a path policy.
335    ///
336    /// Path policies can restrict the set of usable paths based on their characteristics.
337    /// E.g. filtering out paths that go through certain ASes.
338    ///
339    /// See [`HopPatternPolicy`](scion_proto::path::policy::hop_pattern::HopPatternPolicy) and
340    /// [`AclPolicy`](scion_proto::path::policy::acl::AclPolicy)
341    pub fn with_path_policy(mut self, policy: impl PathPolicy) -> Self {
342        self.path_strategy.add_policy(policy);
343        self
344    }
345
346    /// Add a path ranking strategy.
347    ///
348    /// Path Rankings prioritize paths based on their characteristics.
349    ///
350    /// Ranking priority is determined by the order in which they are added to the stack, the first
351    /// having the highest priority.
352    ///
353    /// If no ranking strategies are added, ranking will default to
354    /// [`Shortest`](crate::path::ranking::Shortest).
355    pub fn with_path_ranking(mut self, ranking: impl PathRanking) -> Self {
356        self.path_strategy.add_ranking(ranking);
357        self
358    }
359}
360
361/// A path aware UDP socket generic over the underlay socket and path manager.
362pub struct UdpScionSocket<P: PathManager = CachingPathManager> {
363    socket: PathUnawareUdpScionSocket,
364    pather: Arc<P>,
365    remote_addr: Option<SocketAddr>,
366}
367
368impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
369    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        f.debug_struct("UdpScionSocket")
371            .field("local_addr", &self.socket.local_addr())
372            .field("remote_addr", &self.remote_addr)
373            .finish()
374    }
375}
376
377impl<P: PathManager> UdpScionSocket<P> {
378    /// Creates a new path aware UDP SCION socket.
379    pub fn new(
380        socket: PathUnawareUdpScionSocket,
381        pather: Arc<P>,
382        remote_addr: Option<SocketAddr>,
383    ) -> Self {
384        Self {
385            socket,
386            pather,
387            remote_addr,
388        }
389    }
390
391    /// Connects the socket to a remote address.
392    pub fn connect(self, remote_addr: SocketAddr) -> Self {
393        Self {
394            remote_addr: Some(remote_addr),
395            ..self
396        }
397    }
398
399    /// Send a datagram to the connected remote address.
400    pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
401        if let Some(remote_addr) = self.remote_addr {
402            self.send_to(payload, remote_addr).await
403        } else {
404            Err(ScionSocketSendError::NotConnected)
405        }
406    }
407
408    /// Send a datagram to the specified destination.
409    pub async fn send_to(
410        &self,
411        payload: &[u8],
412        destination: SocketAddr,
413    ) -> Result<(), ScionSocketSendError> {
414        let path = &self
415            .pather
416            .path_wait(
417                self.socket.local_addr().isd_asn(),
418                destination.isd_asn(),
419                Utc::now(),
420            )
421            .await
422            .map_err(|e| {
423                match e {
424                    PathWaitError::FetchFailed(e) => {
425                        ScionSocketSendError::PathLookupError(e.into())
426                    }
427                    PathWaitError::NoPathFound => {
428                        ScionSocketSendError::NetworkUnreachable(
429                            NetworkError::DestinationUnreachable("No path found".to_string()),
430                        )
431                    }
432                }
433            })?;
434        self.socket
435            .send_to_via(payload, destination, &path.to_slice_path())
436            .await
437    }
438
439    /// Send a datagram to the specified destination via the specified path.
440    pub async fn send_to_via(
441        &self,
442        payload: &[u8],
443        destination: SocketAddr,
444        path: &Path<&[u8]>,
445    ) -> Result<(), ScionSocketSendError> {
446        self.socket.send_to_via(payload, destination, path).await
447    }
448
449    /// Receive a datagram from any address, along with the sender address and path.
450    pub async fn recv_from_with_path<'a>(
451        &'a self,
452        buffer: &'a mut [u8],
453        path_buffer: &'a mut [u8],
454    ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
455        let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
456            self.socket.recv_from_with_path(buffer, path_buffer).await?;
457
458        match path.to_reversed() {
459            Ok(reversed_path) => {
460                // Register the path for future use
461                self.pather.register_path(
462                    self.socket.local_addr().isd_asn(),
463                    sender_addr.isd_asn(),
464                    Utc::now(),
465                    reversed_path,
466                );
467            }
468            Err(e) => {
469                tracing::trace!(error = ?e, "Failed to reverse path for registration")
470            }
471        }
472
473        tracing::trace!(
474            src = %self.socket.local_addr(),
475            dst = %sender_addr,
476            "Registered reverse path",
477        );
478
479        Ok((len, sender_addr, path))
480    }
481
482    /// Receive a datagram from the connected remote address and write it into the provided buffer.
483    pub async fn recv_from(
484        &self,
485        buffer: &mut [u8],
486    ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
487        // For this method, we need to get the path to register it, but we don't return it
488        let mut path_buffer = [0u8; 1024]; // Temporary buffer for path
489        let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
490        Ok((len, sender_addr))
491    }
492
493    /// Receive a datagram from the connected remote address.
494    ///
495    /// Datagrams from other addresses are silently discarded.
496    pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
497        if self.remote_addr.is_none() {
498            return Err(ScionSocketReceiveError::NotConnected);
499        }
500        loop {
501            let (len, sender_addr) = self.recv_from(buffer).await?;
502            match self.remote_addr {
503                Some(remote_addr) => {
504                    if sender_addr == remote_addr {
505                        return Ok(len);
506                    }
507                }
508                None => return Err(ScionSocketReceiveError::NotConnected),
509            }
510        }
511    }
512
513    /// Returns the local socket address.
514    pub fn local_addr(&self) -> SocketAddr {
515        self.socket.local_addr()
516    }
517}