Skip to main content

s2n_quic_platform/message/msg/handle.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::ext::Ext as _;
5use crate::{features, message::cmsg::Encoder};
6use libc::msghdr;
7use s2n_quic_core::{
8    ensure,
9    inet::{AncillaryData, SocketAddressV4, Unspecified},
10    path::{self, LocalAddress, RemoteAddress},
11};
12
13#[cfg(any(test, feature = "generator"))]
14use bolero_generator::*;
15
/// Addressing information carried alongside a packet in a `msghdr`.
///
/// Pairs the peer address (`remote_address`) with the local socket address
/// (`local_address`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "generator"), derive(TypeGenerator))]
pub struct Handle {
    /// Address of the peer; written into the `msghdr` via `set_remote_address` when
    /// preparing an outgoing packet.
    pub remote_address: RemoteAddress,
    /// Local socket address for the packet.
    ///
    /// Populated from `AncillaryData` and encoded as a control message when preparing an
    /// outgoing packet. It starts as the unspecified IPv4 address when the handle is built
    /// from only a remote address.
    pub local_address: LocalAddress,
}
22
23impl Handle {
24    #[inline]
25    pub(super) fn with_ancillary_data(&mut self, ancillary_data: AncillaryData) {
26        self.local_address = ancillary_data.local_address;
27    }
28
29    #[inline]
30    pub(super) fn update_msg_hdr(&self, msghdr: &mut msghdr) {
31        // when sending a packet, we start out with no cmsg items
32        msghdr.msg_controllen = 0;
33
34        msghdr.set_remote_address(&self.remote_address.0);
35
36        msghdr
37            .cmsg_encoder()
38            .encode_local_address(&self.local_address.0)
39            .unwrap();
40    }
41}
42
43impl path::Handle for Handle {
44    #[inline]
45    fn from_remote_address(remote_address: RemoteAddress) -> Self {
46        Self {
47            remote_address,
48            local_address: SocketAddressV4::UNSPECIFIED.into(),
49        }
50    }
51
52    #[inline]
53    fn remote_address(&self) -> RemoteAddress {
54        self.remote_address
55    }
56
57    #[inline]
58    fn set_remote_address(&mut self, addr: RemoteAddress) {
59        self.remote_address = addr;
60    }
61
62    #[inline]
63    fn local_address(&self) -> LocalAddress {
64        self.local_address
65    }
66
67    #[inline]
68    fn set_local_address(&mut self, addr: LocalAddress) {
69        self.local_address = addr;
70    }
71
72    #[inline]
73    fn unmapped_eq(&self, other: &Self) -> bool {
74        ensure!(
75            self.remote_address.unmapped_eq(&other.remote_address),
76            false
77        );
78
79        // only compare local addresses if the OS returns them
80        ensure!(features::pktinfo::IS_SUPPORTED, true);
81
82        // Make sure to only compare the fields if they're both set
83        //
84        // This avoids cases where we don't have the full context for the local address and find it
85        // out with a later packet.
86        if !self.local_address.ip().is_unspecified() && !other.local_address.ip().is_unspecified() {
87            ensure!(
88                self.local_address
89                    .ip()
90                    .unmapped_eq(&other.local_address.ip()),
91                false
92            );
93        }
94
95        if self.local_address.port() > 0 && other.local_address.port() > 0 {
96            ensure!(
97                self.local_address.port() == other.local_address.port(),
98                false
99            );
100        }
101
102        true
103    }
104
105    #[inline]
106    fn strict_eq(&self, other: &Self) -> bool {
107        PartialEq::eq(self, other)
108    }
109
110    #[inline]
111    fn maybe_update(&mut self, other: &Self) {
112        self.local_address.maybe_update(&other.local_address);
113    }
114}
115
#[cfg(test)]
mod tests {
    use crate::message::msg::Handle;
    use s2n_quic_core::{
        inet::{IpAddress, IpV4Address},
        path::{Handle as _, LocalAddress},
    };

    /// Builds a handle with a default remote address and the given local IP and port.
    fn handle(ip: &IpAddress, port: u16) -> Handle {
        Handle {
            remote_address: Default::default(),
            local_address: LocalAddress::from(ip.with_port(port)),
        }
    }

    /// Asserts that `unmapped_eq` returns `expected` for `a` and `b` in both argument orders.
    ///
    /// This checks that the relation is symmetric (the result does not depend on argument
    /// order). It does not check reflexivity (`a.unmapped_eq(&a)`).
    fn symmetric_check(a: Handle, b: Handle, expected: bool) {
        assert_eq!(a.unmapped_eq(&b), expected, "{:?} vs {:?}", a, b);
        assert_eq!(b.unmapped_eq(&a), expected, "{:?} vs {:?}", b, a);
    }

    #[test]
    fn unmapped_eq_test() {
        // All of these values should be considered equivalent for local addresses
        let ips: &[IpAddress] = &[
            // if we have an unspecified IP address then don't consider it for equality
            IpV4Address::new([0, 0, 0, 0]).into(),
            // a regular IPv4 IP should match the IPv4-mapped into IPv6
            IpV4Address::new([1, 1, 1, 1]).into(),
            IpV4Address::new([1, 1, 1, 1]).to_ipv6_mapped().into(),
        ];
        // a zero port is treated as unknown and is not considered for equality
        let ports = [0u16, 4440];

        for ip_a in ips {
            for ip_b in ips {
                for port_a in ports {
                    for port_b in ports {
                        symmetric_check(handle(ip_a, port_a), handle(ip_b, port_b), true);
                    }
                }
            }
        }
    }

    #[test]
    fn unmapped_ne_test() {
        // Local addresses are only compared when the OS reports them via pktinfo. Without it,
        // handles with matching remote addresses always compare equal.
        let expected = !crate::features::pktinfo::IS_SUPPORTED;

        let ip_a: IpAddress = IpV4Address::new([1, 1, 1, 1]).into();
        let ip_b: IpAddress = IpV4Address::new([2, 2, 2, 2]).into();

        // both local IPs are specified but differ
        symmetric_check(handle(&ip_a, 4440), handle(&ip_b, 4440), expected);

        // both local ports are set but differ
        symmetric_check(handle(&ip_a, 4440), handle(&ip_a, 4441), expected);
    }
}