Skip to main content

scion_stack/scionstack/
socket.rs

// Copyright 2025 Anapaya Systems
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! SCION socket types.
15
16use std::{sync::Arc, time::Duration};
17
18use bytes::Bytes;
19use chrono::Utc;
20use futures::future::BoxFuture;
21use scion_proto::{
22    address::{ScionAddr, SocketAddr},
23    datagram::UdpMessage,
24    packet::{ByEndpoint, ScionPacketRaw, ScionPacketScmp, ScionPacketUdp},
25    path::Path,
26    scmp::{SCMP_PROTOCOL_NUMBER, ScmpMessage},
27};
28use scion_sdk_quic_scion::socket::{BoxedSocketError, GenericScionUdpSocket};
29
30use super::UnderlaySocket;
31use crate::{
32    path::manager::{MultiPathManager, traits::PathManager},
33    scionstack::{
34        MIN_PATH_BUFFER_SIZE, ScionSocketConnectError, ScionSocketReceiveError,
35        ScionSocketSendError, scmp_handler::ScmpHandler,
36    },
37    types::Subscribers,
38};
39
40/// A path unaware UDP SCION socket.
41pub struct PathUnawareUdpScionSocket {
42    inner: Box<dyn UnderlaySocket + Sync + Send>,
43    /// The SCMP handlers.
44    scmp_handlers: Vec<Box<dyn ScmpHandler>>,
45}
46
47impl std::fmt::Debug for PathUnawareUdpScionSocket {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("PathUnawareUdpScionSocket")
50            .field("local_addr", &self.inner.local_addr())
51            .finish()
52    }
53}
54
55impl PathUnawareUdpScionSocket {
56    pub(crate) fn new(
57        socket: Box<dyn UnderlaySocket + Sync + Send>,
58        scmp_handlers: Vec<Box<dyn ScmpHandler>>,
59    ) -> Self {
60        Self {
61            inner: socket,
62            scmp_handlers,
63        }
64    }
65
66    /// Send a SCION UDP datagram via the given path.
67    ///
68    /// # Cancel safety
69    ///
70    /// This method is cancel-safe. If the future is dropped before completion, the packet may
71    /// be silently lost, but no socket state is corrupted and the socket remains usable.
72    pub fn send_to_via<'a>(
73        &'a self,
74        payload: &[u8],
75        destination: SocketAddr,
76        path: &Path<&[u8]>,
77    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
78        let packet = match ScionPacketUdp::new(
79            ByEndpoint {
80                source: self.inner.local_addr(),
81                destination,
82            },
83            path.data_plane_path.to_bytes_path(),
84            Bytes::copy_from_slice(payload),
85        ) {
86            Ok(packet) => packet,
87            Err(e) => {
88                return Box::pin(async move {
89                    Err(ScionSocketSendError::InvalidPacket(
90                        format!("error encoding packet: {e}").into(),
91                    ))
92                });
93            }
94        }
95        .into();
96        self.inner.send(packet)
97    }
98
99    /// Receive a SCION packet with the sender and path.
100    ///
101    /// # Cancel safety
102    ///
103    /// This method is cancel-safe. The only await point is the inner underlay receive. If the
104    /// future is dropped while waiting for a packet, no packet data is consumed and `buffer`
105    /// and `path_buffer` are left unmodified. If a packet has already been received (i.e., the
106    /// future is dropped after data has been written into the buffers), this cannot occur in
107    /// practice because those steps run synchronously within a single `poll` invocation.
108    #[allow(clippy::type_complexity)]
109    pub fn recv_from_with_path<'a>(
110        &'a self,
111        buffer: &'a mut [u8],
112        path_buffer: &'a mut [u8],
113    ) -> BoxFuture<'a, Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
114    {
115        Box::pin(async move {
116            loop {
117                let packet = self.inner.recv().await?;
118
119                let packet = match packet.headers.common.next_header {
120                    UdpMessage::PROTOCOL_NUMBER => packet,
121                    SCMP_PROTOCOL_NUMBER => {
122                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
123                        for handler in &self.scmp_handlers {
124                            if let Some(reply) = handler.handle(packet.clone())
125                                && let Err(e) = self.inner.try_send(reply)
126                            {
127                                tracing::warn!(error = %e, "failed to send SCMP reply");
128                            }
129                        }
130                        continue;
131                    }
132                    _ => {
133                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
134                        continue;
135                    }
136                };
137
138                let packet: ScionPacketUdp = match packet.try_into() {
139                    Ok(packet) => packet,
140                    Err(e) => {
141                        tracing::debug!(error = %e, "Received invalid UDP packet, skipping");
142                        continue;
143                    }
144                };
145                let src_addr = match packet.headers.address.source() {
146                    Some(source) => SocketAddr::new(source, packet.src_port()),
147                    None => {
148                        tracing::debug!("Received packet without source address header, skipping");
149                        continue;
150                    }
151                };
152                tracing::trace!(
153                    src = %src_addr,
154                    length = packet.datagram.payload.len(),
155                    "received packet",
156                );
157
158                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
159                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
160
161                if path_buffer.len() < packet.headers.path.raw().len() {
162                    return Err(ScionSocketReceiveError::PathBufTooSmall);
163                }
164
165                let dataplane_path = packet
166                    .headers
167                    .path
168                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
169
170                // Note, that we do not have the next hop address of the path.
171                // A socket that uses more than one tunnel will need to distinguish between
172                // packets received on different tunnels.
173                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
174
175                return Ok((packet.datagram.payload.len(), src_addr, path));
176            }
177        })
178    }
179
180    /// Receive a SCION packet with the sender.
181    ///
182    /// # Cancel safety
183    ///
184    /// This method is cancel-safe. If the future is dropped while waiting for a packet, no
185    /// packet is consumed and `buffer` is left unmodified. The contents of `buffer` are only
186    /// valid after the method returns `Ok`.
187    pub fn recv_from<'a>(
188        &'a self,
189        buffer: &'a mut [u8],
190    ) -> BoxFuture<'a, Result<(usize, SocketAddr), ScionSocketReceiveError>> {
191        Box::pin(async move {
192            loop {
193                let packet = self.inner.recv().await?;
194
195                let packet = match packet.headers.common.next_header {
196                    UdpMessage::PROTOCOL_NUMBER => packet,
197                    SCMP_PROTOCOL_NUMBER => {
198                        tracing::debug!("SCMP packet received, forwarding to SCMP handlers");
199                        for handler in &self.scmp_handlers {
200                            if let Some(reply) = handler.handle(packet.clone())
201                                && let Err(e) = self.inner.try_send(reply)
202                            {
203                                tracing::warn!(error = %e, "failed to send SCMP reply");
204                            }
205                        }
206                        continue;
207                    }
208                    _ => {
209                        tracing::debug!(next_header = %packet.headers.common.next_header, "Packet with unknown next layer protocol, skipping");
210                        continue;
211                    }
212                };
213
214                let packet: ScionPacketUdp = match packet.try_into() {
215                    Ok(packet) => packet,
216                    Err(e) => {
217                        tracing::debug!(error = %e, "Received invalid UDP packet, dropping");
218                        continue;
219                    }
220                };
221                let src_addr = match packet.headers.address.source() {
222                    Some(source) => SocketAddr::new(source, packet.src_port()),
223                    None => {
224                        tracing::debug!("Received packet without source address header, dropping");
225                        continue;
226                    }
227                };
228
229                tracing::trace!(
230                    src = %src_addr,
231                    length = packet.datagram.payload.len(),
232                    buffer_size = buffer.len(),
233                    "received packet",
234                );
235
236                let max_read = std::cmp::min(buffer.len(), packet.datagram.payload.len());
237                buffer[..max_read].copy_from_slice(&packet.datagram.payload[..max_read]);
238
239                return Ok((packet.datagram.payload.len(), src_addr));
240            }
241        })
242    }
243
244    /// The local address the socket is bound to.
245    fn local_addr(&self) -> SocketAddr {
246        self.inner.local_addr()
247    }
248}
249
250/// A SCMP SCION socket.
251pub struct ScmpScionSocket {
252    inner: Box<dyn UnderlaySocket + Sync + Send>,
253}
254
255impl ScmpScionSocket {
256    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
257        Self { inner: socket }
258    }
259}
260
261impl ScmpScionSocket {
262    /// Send a SCMP message to the destination via the given path.
263    pub fn send_to_via<'a>(
264        &'a self,
265        message: ScmpMessage,
266        destination: ScionAddr,
267        path: &Path<&[u8]>,
268    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
269        let packet = match ScionPacketScmp::new(
270            ByEndpoint {
271                source: self.inner.local_addr().scion_address(),
272                destination,
273            },
274            path.data_plane_path.to_bytes_path(),
275            message,
276        ) {
277            Ok(packet) => packet,
278            Err(e) => {
279                return Box::pin(async move {
280                    Err(ScionSocketSendError::InvalidPacket(
281                        format!("error encoding packet: {e}").into(),
282                    ))
283                });
284            }
285        };
286        let packet = packet.into();
287        Box::pin(async move { self.inner.send(packet).await })
288    }
289
290    /// Receive a SCMP message with the sender and path.
291    #[allow(clippy::type_complexity)]
292    pub fn recv_from_with_path<'a>(
293        &'a self,
294        path_buffer: &'a mut [u8],
295    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr, Path<&'a mut [u8]>), ScionSocketReceiveError>>
296    {
297        Box::pin(async move {
298            loop {
299                let packet = self.inner.recv().await?;
300                let packet: ScionPacketScmp = match packet.try_into() {
301                    Ok(packet) => packet,
302                    Err(e) => {
303                        tracing::debug!(error = %e, "Received invalid SCMP packet, dropping");
304                        continue;
305                    }
306                };
307                let src_addr = match packet.headers.address.source() {
308                    Some(source) => source,
309                    None => {
310                        tracing::debug!("Received packet without source address header, dropping");
311                        continue;
312                    }
313                };
314
315                if path_buffer.len() < packet.headers.path.raw().len() {
316                    return Err(ScionSocketReceiveError::PathBufTooSmall);
317                }
318                let dataplane_path = packet
319                    .headers
320                    .path
321                    .copy_to_slice(&mut path_buffer[..packet.headers.path.raw().len()]);
322                let path = Path::new(dataplane_path, packet.headers.address.ia, None);
323
324                return Ok((packet.message, src_addr, path));
325            }
326        })
327    }
328
329    /// Receive a SCMP message with the sender.
330    pub fn recv_from<'a>(
331        &'a self,
332    ) -> BoxFuture<'a, Result<(ScmpMessage, ScionAddr), ScionSocketReceiveError>> {
333        Box::pin(async move {
334            loop {
335                let packet = self.inner.recv().await?;
336                let packet: ScionPacketScmp = match packet.try_into() {
337                    Ok(packet) => packet,
338                    Err(e) => {
339                        tracing::debug!(error = %e, "Received invalid SCMP packet, skipping");
340                        continue;
341                    }
342                };
343                let src_addr = match packet.headers.address.source() {
344                    Some(source) => source,
345                    None => {
346                        tracing::debug!("Received packet without source address header, skipping");
347                        continue;
348                    }
349                };
350                return Ok((packet.message, src_addr));
351            }
352        })
353    }
354
355    /// Return the local socket address.
356    pub fn local_addr(&self) -> SocketAddr {
357        self.inner.local_addr()
358    }
359}
360
361/// A raw SCION socket.
362pub struct RawScionSocket {
363    inner: Box<dyn UnderlaySocket>,
364}
365
366impl RawScionSocket {
367    pub(crate) fn new(socket: Box<dyn UnderlaySocket + Sync + Send>) -> Self {
368        Self { inner: socket }
369    }
370}
371
372impl RawScionSocket {
373    /// Send a raw SCION packet.
374    pub fn send<'a>(
375        &'a self,
376        packet: ScionPacketRaw,
377    ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
378        self.inner.send(packet)
379    }
380
381    /// Receive a raw SCION packet.
382    pub fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
383        self.inner.recv()
384    }
385
386    /// Return the local socket address.
387    pub fn local_addr(&self) -> SocketAddr {
388        self.inner.local_addr()
389    }
390}
391
392/// A trait for receiving socket send errors.
393pub trait SendErrorReceiver: Send + Sync {
394    /// Reports an error when sending a packet.
395    /// This function must return immediately and not block.
396    fn report_send_error(&self, error: &ScionSocketSendError);
397}
398
399/// A path aware UDP socket generic over the underlay socket and path manager.
400pub struct UdpScionSocket<P: PathManager = MultiPathManager> {
401    socket: PathUnawareUdpScionSocket,
402    pather: Arc<P>,
403    connect_timeout: Duration,
404    remote_addr: Option<SocketAddr>,
405    send_error_receivers: Subscribers<dyn SendErrorReceiver>,
406}
407
408impl<P: PathManager> std::fmt::Debug for UdpScionSocket<P> {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        f.debug_struct("UdpScionSocket")
411            .field("local_addr", &self.socket.local_addr())
412            .field("remote_addr", &self.remote_addr)
413            .finish()
414    }
415}
416
417impl<P: PathManager> UdpScionSocket<P> {
418    /// Creates a new path aware UDP SCION socket.
419    pub fn new(
420        socket: PathUnawareUdpScionSocket,
421        pather: Arc<P>,
422        connect_timeout: Duration,
423        send_error_receivers: Subscribers<dyn SendErrorReceiver>,
424    ) -> Self {
425        Self {
426            socket,
427            pather,
428            connect_timeout,
429            remote_addr: None,
430            send_error_receivers,
431        }
432    }
433
434    /// Connects the socket to a remote address.
435    ///
436    /// Ensures a Path to the Destination exists, returns an error if not.
437    ///
438    /// Timeouts after configured `connect_timeout`
439    pub async fn connect(self, remote_addr: SocketAddr) -> Result<Self, ScionSocketConnectError> {
440        // Check that a path exists to destination
441        let _path = self
442            .pather
443            .path_timeout(
444                self.socket.local_addr().isd_asn(),
445                remote_addr.isd_asn(),
446                Utc::now(),
447                self.connect_timeout,
448            )
449            .await?;
450
451        Ok(Self {
452            remote_addr: Some(remote_addr),
453            ..self
454        })
455    }
456
457    /// Send a datagram to the connected remote address.
458    ///
459    /// # Cancel safety
460    ///
461    /// This method is cancel-safe. If the future is dropped before completion, the packet may
462    /// be silently lost, but no socket state is corrupted and the socket remains usable.
463    pub async fn send(&self, payload: &[u8]) -> Result<(), ScionSocketSendError> {
464        if let Some(remote_addr) = self.remote_addr {
465            self.send_to(payload, remote_addr).await
466        } else {
467            Err(ScionSocketSendError::NotConnected)
468        }
469    }
470
471    /// Send a datagram to the specified destination.
472    ///
473    /// # Cancel safety
474    ///
475    /// This method is cancel-safe. It has two await points: the path lookup and the actual send.
476    /// If the future is dropped at either point, no socket state is corrupted and the socket
477    /// remains usable. A packet dropped mid-send is silently lost, which is normal for UDP.
478    pub async fn send_to(
479        &self,
480        payload: &[u8],
481        destination: SocketAddr,
482    ) -> Result<(), ScionSocketSendError> {
483        let path = &self
484            .pather
485            .path_wait(
486                self.socket.local_addr().isd_asn(),
487                destination.isd_asn(),
488                Utc::now(),
489            )
490            .await?;
491        self.socket
492            .send_to_via(payload, destination, &path.to_slice_path())
493            .await
494    }
495
496    /// Send a datagram to the specified destination via the specified path.
497    ///
498    /// # Cancel safety
499    ///
500    /// This method is cancel-safe. If the future is dropped before completion, the packet may
501    /// be silently lost, but no socket state is corrupted and the socket remains usable.
502    pub async fn send_to_via(
503        &self,
504        payload: &[u8],
505        destination: SocketAddr,
506        path: &Path<&[u8]>,
507    ) -> Result<(), ScionSocketSendError> {
508        self.socket
509            .send_to_via(payload, destination, path)
510            .await
511            .inspect_err(|e| {
512                self.send_error_receivers
513                    .for_each(|receiver| receiver.report_send_error(e));
514            })
515    }
516
517    /// Receive a datagram from any address, along with the sender address and path.
518    ///
519    /// # Cancel safety
520    ///
521    /// This method is cancel-safe. The only await point is the inner underlay receive. If the
522    /// future is dropped while waiting for a packet, no packet data is consumed and `buffer`
523    /// and `path_buffer` are left unmodified.
524    ///
525    /// Path registration via the path manager runs synchronously within the same `poll`
526    /// invocation that delivers the received data, so it cannot be independently cancelled.
527    pub async fn recv_from_with_path<'a>(
528        &'a self,
529        buffer: &'a mut [u8],
530        path_buffer: &'a mut [u8],
531    ) -> Result<(usize, SocketAddr, Path<&'a mut [u8]>), ScionSocketReceiveError> {
532        let (len, sender_addr, path): (usize, SocketAddr, Path<&mut [u8]>) =
533            self.socket.recv_from_with_path(buffer, path_buffer).await?;
534
535        match path.to_reversed() {
536            Ok(reversed_path) => {
537                // Register the path for future use
538                self.pather.register_path(
539                    self.socket.local_addr().isd_asn(),
540                    sender_addr.isd_asn(),
541                    Utc::now(),
542                    reversed_path,
543                );
544            }
545            Err(e) => {
546                tracing::trace!(error = ?e, "Failed to reverse path for registration")
547            }
548        }
549
550        tracing::trace!(
551            src = %self.socket.local_addr(),
552            dst = %sender_addr,
553            "Registered reverse path",
554        );
555
556        Ok((len, sender_addr, path))
557    }
558
559    /// Receive a datagram from the connected remote address and write it into the provided buffer.
560    ///
561    /// The path of the received packet is used to register a reverse path with the path manager,
562    /// but is not returned to the caller. Use [`recv_from_with_path`](Self::recv_from_with_path)
563    /// if the path is needed.
564    ///
565    /// # Cancel safety
566    ///
567    /// This method is cancel-safe. If the future is dropped while waiting for a packet, no
568    /// packet is consumed and `buffer` is left unmodified. The contents of `buffer` are only
569    /// valid after the method returns `Ok`.
570    pub async fn recv_from(
571        &self,
572        buffer: &mut [u8],
573    ) -> Result<(usize, SocketAddr), ScionSocketReceiveError> {
574        let mut path_buffer = [0u8; MIN_PATH_BUFFER_SIZE];
575        let (len, sender_addr, _) = self.recv_from_with_path(buffer, &mut path_buffer).await?;
576        Ok((len, sender_addr))
577    }
578
579    /// Receive a datagram from the connected remote address.
580    ///
581    /// Datagrams from other addresses are silently discarded.
582    ///
583    /// # Cancel safety
584    ///
585    /// This method is cancel-safe. If the future is dropped while waiting for a packet, no
586    /// packet is permanently lost — the underlying receive is cancel-safe and an undelivered
587    /// packet remains available for the next call. Note that packets from other senders are
588    /// discarded during filtering; those discarded packets are not recoverable regardless of
589    /// cancellation. The contents of `buffer` are only valid after the method returns `Ok(n)`.
590    pub async fn recv(&self, buffer: &mut [u8]) -> Result<usize, ScionSocketReceiveError> {
591        if self.remote_addr.is_none() {
592            return Err(ScionSocketReceiveError::NotConnected);
593        }
594        loop {
595            let (len, sender_addr) = self.recv_from(buffer).await?;
596            match self.remote_addr {
597                Some(remote_addr) => {
598                    if sender_addr == remote_addr {
599                        return Ok(len);
600                    }
601                }
602                None => return Err(ScionSocketReceiveError::NotConnected),
603            }
604        }
605    }
606
607    /// Returns the local socket address.
608    pub fn local_addr(&self) -> SocketAddr {
609        self.socket.local_addr()
610    }
611}
612
613// Allow using `UdpScionSocket` as a `GenericScionUdpSocket` for compatibility with QUIC and HTTP/3
614// implementations.
615#[async_trait::async_trait]
616impl<P: PathManager + Sync + Send + 'static> GenericScionUdpSocket for UdpScionSocket<P> {
617    /// Asynchronously sends a Datagram to the specified destination address.
618    async fn send_to(
619        &self,
620        payload: &[u8],
621        destination: SocketAddr,
622    ) -> Result<(), BoxedSocketError> {
623        self.send_to(payload, destination)
624            .await
625            .map_err(|e| Box::new(e) as BoxedSocketError)
626    }
627
628    /// Asynchronously receives a Datagram, writing it into the provided buffer, and returns the
629    /// number of bytes read and the source address.
630    async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), BoxedSocketError> {
631        self.recv_from(buf)
632            .await
633            .map_err(|e| Box::new(e) as BoxedSocketError)
634    }
635
636    /// Returns the local socket address of this socket.
637    fn local_addr(&self) -> SocketAddr {
638        self.local_addr()
639    }
640}
641
642#[cfg(test)]
643mod cancel_safety_tests {
644    //! Unit tests verifying that all async methods on [`UdpScionSocket`] and
645    //! [`PathUnawareUdpScionSocket`] are cancel-safe.
646    //!
647    //! The tests use two hand-rolled test doubles rather than the real underlay and path manager:
648    //!
649    //! - [`ManualUnderlaySocket`]: backed by a bounded `tokio::sync::mpsc` channel. Injecting
650    //!   packets is done via the paired `Sender`. The `recv` future is backed by
651    //!   `tokio::sync::mpsc::Receiver::recv()`, which IS cancel-safe (the message stays in the
652    //!   channel if the future is dropped before returning `Ready`).
653    //!
654    //! - [`ImmediatePathManager`]: always returns a local (empty) path immediately, so tests do not
655    //!   depend on any background task.
656    //!
657    //! ## What these tests verify
658    //!
659    //! The tests verify that dropping a future at realistically reachable await points (the inner
660    //! underlay `recv`) leaves no corrupted socket state and that unconsumed packets remain
661    //! available for the next caller. They also verify that the wrong-sender filtering loop in
662    //! [`UdpScionSocket::recv`] can be safely cancelled mid-iteration.
663    //!
664    //! Because all processing steps after the underlay `recv` resolves run synchronously within
665    //! the same `poll()` invocation, there is no intermediate await point between "data received"
666    //! and "data returned" that could be independently cancelled. The tests therefore focus on
667    //! the cancel points that actually exist at runtime.
668
669    use std::{
670        io,
671        net::Ipv4Addr,
672        sync::{Arc, Mutex},
673    };
674
675    use bytes::Bytes;
676    use chrono::{DateTime, Utc};
677    use futures::future::BoxFuture;
678    use scion_proto::{
679        address::{Asn, Isd, IsdAsn, ScionAddr, SocketAddr},
680        packet::ScionPacketRaw,
681        path::{Path, test_builder::TestPathBuilder},
682    };
683
684    use super::*;
685    use crate::{
686        path::manager::traits::{PathWaitError, SyncPathManager},
687        scionstack::{ScionSocketReceiveError, ScionSocketSendError, UnderlaySocket},
688        types::{ResFut, Subscribers},
689    };
690
691    struct ManualUnderlaySocket {
692        local: SocketAddr,
693        rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<ScionPacketRaw>>,
694    }
695
696    impl ManualUnderlaySocket {
697        fn new(local: SocketAddr) -> (Self, tokio::sync::mpsc::Sender<ScionPacketRaw>) {
698            // Use a large bounded channel so tests never block on send.
699            let (inject_tx, recv_rx) = tokio::sync::mpsc::channel::<ScionPacketRaw>(64);
700            let socket = Self {
701                local,
702                rx: tokio::sync::Mutex::new(recv_rx),
703            };
704            (socket, inject_tx)
705        }
706    }
707
708    impl UnderlaySocket for ManualUnderlaySocket {
709        fn send<'a>(
710            &'a self,
711            _packet: ScionPacketRaw,
712        ) -> BoxFuture<'a, Result<(), ScionSocketSendError>> {
713            Box::pin(async move { Ok(()) })
714        }
715
716        fn try_send(&self, _packet: ScionPacketRaw) -> Result<(), ScionSocketSendError> {
717            Ok(())
718        }
719
720        fn recv<'a>(&'a self) -> BoxFuture<'a, Result<ScionPacketRaw, ScionSocketReceiveError>> {
721            Box::pin(async move {
722                let packet = self.rx.lock().await.recv().await.ok_or_else(|| {
723                    ScionSocketReceiveError::IoError(io::Error::other("channel closed"))
724                })?;
725                Ok(packet)
726            })
727        }
728
729        fn local_addr(&self) -> SocketAddr {
730            self.local
731        }
732
733        fn snap_data_plane(&self) -> Option<std::net::SocketAddr> {
734            None
735        }
736    }
737
738    #[derive(Default)]
739    struct ImmediatePathManager {
740        registered_paths: Mutex<Vec<Path<Bytes>>>,
741    }
742
743    impl SyncPathManager for ImmediatePathManager {
744        fn register_path(
745            &self,
746            _src: IsdAsn,
747            _dst: IsdAsn,
748            _now: DateTime<Utc>,
749            path: Path<Bytes>,
750        ) {
751            self.registered_paths.lock().expect("poisoned").push(path);
752        }
753
754        fn try_cached_path(
755            &self,
756            src: IsdAsn,
757            _dst: IsdAsn,
758            _now: DateTime<Utc>,
759        ) -> io::Result<Option<Path<Bytes>>> {
760            Ok(Some(Path::local(src)))
761        }
762    }
763
764    impl PathManager for ImmediatePathManager {
765        fn path_wait(
766            &self,
767            src: IsdAsn,
768            _dst: IsdAsn,
769            _now: DateTime<Utc>,
770        ) -> impl ResFut<'_, Path<Bytes>, PathWaitError> {
771            async move { Ok(Path::local(src)) }
772        }
773    }
774
775    const LOCAL_ISD_ASN: IsdAsn = IsdAsn::new(Isd(1), Asn(1));
776    const REMOTE_ISD_ASN: IsdAsn = IsdAsn::new(Isd(1), Asn(2));
777    const OTHER_ISD_ASN: IsdAsn = IsdAsn::new(Isd(1), Asn(3));
778
779    fn local_addr() -> SocketAddr {
780        SocketAddr::new(
781            ScionAddr::new(LOCAL_ISD_ASN, Ipv4Addr::new(127, 0, 0, 1).into()),
782            8080,
783        )
784    }
785
786    fn remote_addr() -> SocketAddr {
787        SocketAddr::new(
788            ScionAddr::new(REMOTE_ISD_ASN, Ipv4Addr::new(127, 0, 0, 2).into()),
789            9090,
790        )
791    }
792
793    fn other_addr() -> SocketAddr {
794        SocketAddr::new(
795            ScionAddr::new(OTHER_ISD_ASN, Ipv4Addr::new(127, 0, 0, 3).into()),
796            7070,
797        )
798    }
799
800    /// Build a [`TestPathContext`] carrying a path from `src` to `dst`.
801    fn test_path_ctx(
802        src: ScionAddr,
803        dst: ScionAddr,
804    ) -> scion_proto::path::test_builder::TestPathContext {
805        TestPathBuilder::new(src, dst)
806            .using_info_timestamp(1_000_000)
807            .up()
808            .add_hop(0, 1)
809            .add_hop(1, 0)
810            .build(1_000_000)
811    }
812
813    /// Create a valid [`ScionPacketRaw`] that looks like a UDP packet from `src` to `dst`
814    /// with `payload`.
815    fn make_udp_raw(src: SocketAddr, dst: SocketAddr, payload: &[u8]) -> ScionPacketRaw {
816        let ctx = test_path_ctx(src.scion_address(), dst.scion_address());
817        ctx.scion_packet_udp(payload, src.port(), dst.port()).into()
818    }
819
820    /// Build a connected [`UdpScionSocket`] backed by the test doubles.
821    /// Returns the socket, the packet injector, and the path manager.
822    fn build_socket() -> (
823        UdpScionSocket<ImmediatePathManager>,
824        tokio::sync::mpsc::Sender<ScionPacketRaw>,
825        Arc<ImmediatePathManager>,
826    ) {
827        let (underlay, inject_tx) = ManualUnderlaySocket::new(local_addr());
828        let pather = Arc::new(ImmediatePathManager::default());
829        let path_unaware = PathUnawareUdpScionSocket::new(
830            Box::new(underlay),
831            vec![], // no SCMP handlers needed
832        );
833        let socket = UdpScionSocket::new(
834            path_unaware,
835            pather.clone(),
836            std::time::Duration::from_secs(5),
837            Subscribers::new(),
838        );
839        (socket, inject_tx, pather)
840    }
841
842    // ─── Tests ─────────────────────────────────────────────────────────────────
843
844    /// Dropping a [`recv_from_with_path`] future while it is pending (waiting in the channel)
845    /// must not consume the packet. The next call must receive that packet.
846    ///
847    /// This verifies that the underlay's `recv` future is cancel-safe: the message stays in the
848    /// channel when the outer future is dropped before returning `Ready`.
849    #[tokio::test]
850    async fn recv_from_with_path_cancel_while_pending_does_not_lose_packet() {
851        let (socket, inject_tx, _pather) = build_socket();
852
853        // Poll once — returns Pending because the channel is empty.
854        {
855            let (mut buf, mut pbuf) = ([0u8; 64], [0u8; 1024]);
856            let mut fut = std::pin::pin!(socket.recv_from_with_path(&mut buf, &mut pbuf));
857            let waker = futures::task::noop_waker();
858            let mut cx = std::task::Context::from_waker(&waker);
859            // The future must be Pending (no packet injected yet).
860            assert!(fut.as_mut().poll(&mut cx).is_pending());
861            // Drop `fut` here — the future is cancelled while pending.
862        }
863
864        // Inject the packet AFTER the first future was dropped.
865        let payload = b"cancel-safe";
866        inject_tx
867            .try_send(make_udp_raw(remote_addr(), local_addr(), payload))
868            .unwrap();
869
870        // The packet must be available to the next future.
871        let (mut buf2, mut pbuf2) = (vec![0u8; 64], vec![0u8; 1024]);
872        let (len, sender, _path) = socket
873            .recv_from_with_path(&mut buf2, &mut pbuf2)
874            .await
875            .unwrap();
876
877        assert_eq!(len, payload.len());
878        assert_eq!(&buf2[..len], payload);
879        assert_eq!(sender, remote_addr());
880    }
881
882    /// `recv` (connected socket) correctly filters wrong-sender packets and returns
883    /// the packet from the connected remote address.
884    #[tokio::test]
885    async fn recv_filters_wrong_sender_and_delivers_correct_packet() {
886        let (mut socket, inject_tx, _pather) = build_socket();
887        // Connect to remote_addr.
888        socket.remote_addr = Some(remote_addr());
889
890        // Inject wrong-sender packet first, then correct-sender packet.
891        inject_tx
892            .try_send(make_udp_raw(other_addr(), local_addr(), b"wrong"))
893            .unwrap();
894        inject_tx
895            .try_send(make_udp_raw(remote_addr(), local_addr(), b"correct"))
896            .unwrap();
897
898        let mut buf = [0u8; 64];
899        let len = socket.recv(&mut buf).await.unwrap();
900        assert_eq!(&buf[..len], b"correct");
901    }
902
903    /// After cancelling `recv` mid-filtering (a wrong-sender packet was consumed),
904    /// the socket must still be usable and must deliver subsequent correct-sender packets.
905    #[tokio::test]
906    async fn recv_cancel_during_filtering_socket_remains_usable() {
907        let (mut socket, inject_tx, _pather) = build_socket();
908        socket.remote_addr = Some(remote_addr());
909
910        // Inject only a wrong-sender packet — `recv` will consume it and loop back
911        // to await the next packet (Pending at that point).
912        inject_tx
913            .try_send(make_udp_raw(other_addr(), local_addr(), b"wrong"))
914            .unwrap();
915
916        // Poll once with a noop waker: recv processes the wrong-sender packet, finds it does not
917        // match the connected address, and loops back to yield on the inner recv (Pending).
918        // No Tokio runtime involvement is needed here — the channel already holds the packet.
919        {
920            let mut filter_buf = [0u8; 64];
921            let mut fut = std::pin::pin!(socket.recv(&mut filter_buf));
922            let waker = futures::task::noop_waker();
923            let mut cx = std::task::Context::from_waker(&waker);
924            assert!(
925                fut.as_mut().poll(&mut cx).is_pending(),
926                "recv must be Pending after consuming wrong-sender packet"
927            );
928            // Drop the future here — the wrong-sender packet has been consumed and discarded.
929        }
930
931        // Now inject a correct-sender packet and verify the socket is still usable.
932        inject_tx
933            .try_send(make_udp_raw(remote_addr(), local_addr(), b"after-cancel"))
934            .unwrap();
935
936        let mut buf = [0u8; 64];
937        let len = socket.recv(&mut buf).await.unwrap();
938        assert_eq!(&buf[..len], b"after-cancel");
939    }
940
941    /// Buffer contents are only valid after a successful `Ok` return; after a
942    /// cancel and retry the buffer must contain the correct data from the retry.
943    #[tokio::test]
944    async fn recv_from_buffer_valid_only_after_ok() {
945        let (socket, inject_tx, _pather) = build_socket();
946
947        // Pre-fill buffer with sentinel bytes.
948        let mut buf = [0xFFu8; 64];
949
950        // Cancel while pending (no packet).
951        {
952            let mut fut = std::pin::pin!(socket.recv_from(&mut buf));
953            let waker = futures::task::noop_waker();
954            let mut cx = std::task::Context::from_waker(&waker);
955            assert!(fut.as_mut().poll(&mut cx).is_pending());
956        }
957
958        // Inject a packet with known payload.
959        let payload = b"real-data";
960        inject_tx
961            .try_send(make_udp_raw(remote_addr(), local_addr(), payload))
962            .unwrap();
963
964        let (len, _sender) = socket.recv_from(&mut buf).await.unwrap();
965        assert_eq!(len, payload.len());
966        assert_eq!(
967            &buf[..len],
968            payload,
969            "buffer must contain the real payload after Ok return"
970        );
971    }
972}