seqknock_common/
socket.rs

1/*
2 * Copyright 2023 Jonas Eriksson
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
use libc::{
    __errno_location,
    __u32,
    c_int,
    c_void,
    close as libc_close,
    connect as libc_connect,
    in6_addr,
    in_addr,
    setsockopt,
    sockaddr,
    sockaddr_in,
    sockaddr_in6,
    socket,
    AF_INET,
    AF_INET6,
    EINPROGRESS,
    SOCK_CLOEXEC,
    SOCK_STREAM,
    SOL_TCP,
    //TCP_SEND_QUEUE,
    TCP_QUEUE_SEQ,
    TCP_REPAIR,
    TCP_REPAIR_QUEUE,
};
41
/// Queue selector for `TCP_REPAIR_QUEUE` that picks the send queue.
///
/// The libc crate does not export this constant (and musl does not appear to define it either),
/// so it is defined here. Its value is fixed by the Linux kernel UAPI: `include/uapi/linux/tcp.h`
/// declares `enum { TCP_NO_QUEUE, TCP_RECV_QUEUE, TCP_SEND_QUEUE, TCP_QUEUES_NR }`, which makes
/// `TCP_SEND_QUEUE == 2`.
const TCP_SEND_QUEUE: u32 = 2;
45
46use std::io::{Error, ErrorKind, Result};
47use std::mem::size_of;
48use std::net::SocketAddr;
49use std::net::ToSocketAddrs;
50use std::os::unix::io::FromRawFd;
51
52use log::debug;
53
54#[cfg(feature = "async")]
55use async_io;
56#[cfg(feature = "async")]
57use async_std;
58
/// Address family that connection attempts can be restricted to.
///
/// Passed as `force_family` to `connect`/`connect_async`; passing `None` there accepts
/// addresses of either family.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Family {
    /// Only IPv4 addresses are tried.
    V4,
    /// Only IPv6 addresses are tried.
    V6,
}
64
65fn family_matches(socket_addr: &SocketAddr, family: Option<Family>) -> bool {
66    if let Some(f) = family {
67        if f == Family::V4 && !socket_addr.is_ipv4() {
68            return false;
69        }
70        if f == Family::V6 && !socket_addr.is_ipv6() {
71            return false;
72        }
73    }
74    true
75}
76
77pub fn connect<A: ToSocketAddrs>(
78    sequence_no: u32,
79    addr: A,
80    force_family: Option<Family>,
81) -> Result<std::net::TcpStream> {
82    unsafe {
83        // Lookup addresses
84        let socket_addrs = addr.to_socket_addrs()?;
85
86        // Try to connect to all addresses
87        let mut maybe_err = None;
88        for socket_addr in socket_addrs {
89            if !family_matches(&socket_addr, force_family) {
90                debug!("skipping {}, not of requested family", socket_addr);
91                continue;
92            }
93            debug!("Trying to connect to {}", socket_addr);
94
95            let sock = create_socket(family_of(&socket_addr), sequence_no)?;
96
97            match connect_socket(sock, socket_addr, true) {
98                Ok(()) => {
99                    debug!("Connected to {}", socket_addr);
100                    // Create stream from socket fd
101                    return Ok(std::net::TcpStream::from_raw_fd(sock));
102                }
103                Err(e) => maybe_err = Some(e),
104            }
105
106            libc_close(sock);
107        }
108
109        if let Some(e) = maybe_err {
110            return Err(e);
111        }
112    }
113    Err(Error::new(
114        ErrorKind::AddrNotAvailable,
115        "No address entries for hostname",
116    ))
117}
118
/// Asynchronously opens a TCP connection to `socket_addr` whose initial sequence number is
/// `sequence_no`.
///
/// Addresses are tried in resolution order, skipping those that do not match `force_family`
/// (when given). Each attempt issues `connect(2)` on a non-blocking socket and awaits
/// completion through `async_io`, so the executor thread is never blocked by the handshake.
///
/// # Errors
/// - Name-resolution errors are returned as-is, as are errors from reading `SO_ERROR` and from
///   unwrapping the connected socket.
/// - If every candidate address failed, the error from the last failed attempt is returned.
/// - If no candidate address remained at all, an `ErrorKind::AddrNotAvailable` error is returned.
#[cfg(feature = "async")]
pub async fn connect_async<A: async_std::net::ToSocketAddrs>(
    sequence_no: u32,
    socket_addr: A,
    force_family: Option<Family>,
) -> Result<async_std::net::TcpStream> {
    let socket_addrs = socket_addr.to_socket_addrs().await?;

    let mut maybe_err = None;
    for socket_addr in socket_addrs {
        if !family_matches(&socket_addr, force_family) {
            debug!("skipping {}, not of requested family", socket_addr);
            continue;
        }
        debug!("Trying to connect to {}", socket_addr);

        // A socket-setup failure for one address must not stop us from trying the others.
        // SAFETY: create_socket only issues libc socket calls; on success we own the returned fd.
        let sock = match unsafe { create_socket(family_of(&socket_addr), sequence_no) } {
            Ok(sock) => sock,
            Err(e) => {
                maybe_err = Some(e);
                continue;
            }
        };

        // Hand the fd to an owning wrapper immediately. From here on, every failure path closes
        // it exactly once, when `std_stream` (or the `Async` wrapping it) is dropped. No manual
        // close() is needed, so the fd is never closed twice.
        // SAFETY: `sock` is a freshly created, open fd that nothing else owns.
        let std_stream = unsafe { std::net::TcpStream::from_raw_fd(sock) };

        // Connect in non-blocking mode so this async fn never blocks the executor thread.
        // EINPROGRESS is accepted and completion is awaited below.
        if let Err(e) = std_stream.set_nonblocking(true) {
            maybe_err = Some(e);
            continue;
        }
        // SAFETY: `sock` stays open for as long as `std_stream` is alive.
        if let Err(e) = unsafe { connect_socket(sock, socket_addr, false) } {
            maybe_err = Some(e);
            continue;
        }

        let stream = match async_io::Async::new(std_stream) {
            Ok(s) => s,
            Err(e) => {
                maybe_err = Some(e);
                continue;
            }
        };

        // The socket turns writable once the handshake has finished, successfully or not.
        // SO_ERROR (read via take_error) tells which.
        if let Err(e) = stream.writable().await {
            maybe_err = Some(e);
            continue;
        }
        match stream.get_ref().take_error()? {
            None => {
                debug!("Connected to {}", socket_addr);
                // .into_inner() failing is such a rare case, don't apply
                // special next-IP handling if it fails.
                return Ok(stream.into_inner()?.into());
            }
            Some(e) => maybe_err = Some(e),
        }
    }

    Err(maybe_err.unwrap_or_else(|| {
        Error::new(
            ErrorKind::AddrNotAvailable,
            "No address entries for hostname",
        )
    }))
}
183
184unsafe fn connect_socket(
185    sock: c_int,
186    socket_addr: std::net::SocketAddr,
187    blocking: bool,
188) -> Result<()> {
189    match socket_addr {
190        SocketAddr::V4(v4addr) => {
191            let octets = v4addr.ip().octets();
192            let u32_addr: u32 = (octets[0] as u32)
193                | (octets[1] as u32) << 8
194                | (octets[2] as u32) << 16
195                | (octets[3] as u32) << 24;
196            let saddr = sockaddr_in {
197                sin_family: AF_INET as u16,
198                sin_port: v4addr.port().to_be(),
199                sin_addr: in_addr { s_addr: u32_addr },
200                sin_zero: [0; 8],
201            };
202            let result = libc_connect(
203                sock,
204                &saddr as *const sockaddr_in as *const sockaddr,
205                size_of::<sockaddr_in>() as u32,
206            );
207            if result < 0 && (blocking || (*__errno_location()) != EINPROGRESS) {
208                return Err(Error::last_os_error());
209            }
210        }
211        SocketAddr::V6(v6addr) => {
212            let saddr = sockaddr_in6 {
213                sin6_family: AF_INET6 as u16,
214                sin6_port: v6addr.port().to_be(),
215                sin6_flowinfo: 0,
216                sin6_addr: in6_addr {
217                    s6_addr: v6addr.ip().octets(),
218                },
219                sin6_scope_id: 0,
220            };
221            let result = libc_connect(
222                sock,
223                &saddr as *const sockaddr_in6 as *const sockaddr,
224                size_of::<sockaddr_in6>() as u32,
225            );
226            if result < 0 && (blocking || (*__errno_location()) != EINPROGRESS) {
227                return Err(Error::last_os_error());
228            }
229        }
230    }
231
232    Ok(())
233}
234
235unsafe fn sso_tcp_wrapper(sock: c_int, cmd: c_int, data: u32) -> Result<()> {
236    let dataptr = &data as *const __u32 as *const c_void;
237    if setsockopt(sock, SOL_TCP, cmd, dataptr, 4) < 0 {
238        return Err(Error::last_os_error());
239    }
240    Ok(())
241}
242
243unsafe fn create_socket(family: c_int, sequence_no: u32) -> Result<c_int> {
244    // Create socket
245    let sock: c_int = socket(family, SOCK_STREAM, 0);
246    if sock < 0 {
247        return Err(Error::last_os_error());
248    }
249
250    // Enter repair mode
251    sso_tcp_wrapper(sock, TCP_REPAIR, 1)?;
252    // Enter repair queue mode for the send queue
253    sso_tcp_wrapper(sock, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)?;
254    // Set sequence number
255    sso_tcp_wrapper(sock, TCP_QUEUE_SEQ, sequence_no)?;
256    // Exit repair mode
257    sso_tcp_wrapper(sock, TCP_REPAIR, 0)?;
258
259    Ok(sock)
260}
261
262fn family_of(socket_addr: &std::net::SocketAddr) -> c_int {
263    match socket_addr {
264        SocketAddr::V4(_) => AF_INET,
265        SocketAddr::V6(_) => AF_INET6,
266    }
267}