s2n_quic_platform/syscall/
msg.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{SocketEvents, SocketType, UnixMessage};
5use crate::{message::Message as _, socket::stats};
6use libc::msghdr;
7use std::os::unix::io::{AsRawFd, RawFd};
8
9impl UnixMessage for msghdr {
10    #[inline]
11    fn send<E: SocketEvents>(
12        fd: RawFd,
13        entries: &mut [Self],
14        events: &mut E,
15        stats: &stats::Sender,
16    ) {
17        send(&fd, entries, events, stats)
18    }
19
20    #[inline]
21    fn recv<E: SocketEvents>(
22        fd: RawFd,
23        ty: SocketType,
24        entries: &mut [Self],
25        events: &mut E,
26        stats: &stats::Sender,
27    ) {
28        recv(&fd, ty, entries, events, stats)
29    }
30}
31
32#[inline]
33pub fn send<'a, Sock: AsRawFd, P: IntoIterator<Item = &'a mut msghdr>, E: SocketEvents>(
34    socket: &Sock,
35    packets: P,
36    events: &mut E,
37    stats: &stats::Sender,
38) {
39    for packet in packets {
40        #[cfg(debug_assertions)]
41        let prev_msg_control_ptr = packet.msg_control;
42
43        // macOS doesn't like when msg_control have valid pointers but the len is 0
44        //
45        // If that's the case here, then set the `msg_control` to null and restore it after
46        // calling sendmsg.
47        #[cfg(any(target_os = "macos", target_os = "ios", test))]
48        let msg_control = {
49            let msg_control = packet.msg_control;
50
51            if packet.msg_controllen == 0 {
52                packet.msg_control = core::ptr::null_mut();
53            }
54
55            msg_control
56        };
57
58        // Safety: calling a libc function is inherently unsafe as rust cannot
59        // make any invariant guarantees. This has to be reviewed by humans instead
60        // so the [docs](https://linux.die.net/man/2/sendmsg) are inlined here:
61
62        // > The argument sockfd is the file descriptor of the sending socket.
63        let sockfd = socket.as_raw_fd();
64
65        // > The address of the target is given by msg.msg_name, with msg.msg_namelen
66        // > specifying its size.
67        //
68        // > The message is pointed to by the elements of the array msg.msg_iov.
69        // > The sendmsg() call also allows sending ancillary data (also known as
70        // > control information).
71        let msg = packet;
72
73        // > The flags argument is the bitwise OR of zero or more flags.
74        //
75        // No flags are currently set
76        let flags = Default::default();
77
78        // > On success, these calls return the number of characters sent.
79        // > On error, -1 is returned, and errno is set appropriately.
80        let result = libc!(sendmsg(sockfd, msg, flags));
81
82        // restore the msg_control pointer if needed
83        #[cfg(any(target_os = "macos", target_os = "ios", test))]
84        {
85            msg.msg_control = msg_control;
86        }
87
88        #[cfg(debug_assertions)]
89        {
90            assert_eq!(
91                prev_msg_control_ptr, msg.msg_control,
92                "msg_control pointer was modified by the OS"
93            );
94        }
95
96        stats.send().on_operation_result(&result, |_len| 1);
97
98        let cf = match result {
99            Ok(_) => events.on_complete(1),
100            Err(err) => events.on_error(err),
101        };
102
103        if cf.is_break() {
104            return;
105        }
106    }
107}
108
109#[inline]
110pub fn recv<'a, Sock: AsRawFd, P: IntoIterator<Item = &'a mut msghdr>, E: SocketEvents>(
111    socket: &Sock,
112    socket_type: SocketType,
113    packets: P,
114    events: &mut E,
115    stats: &stats::Sender,
116) {
117    let mut flags = match socket_type {
118        SocketType::Blocking => Default::default(),
119        SocketType::NonBlocking => libc::MSG_DONTWAIT,
120    };
121
122    for packet in packets {
123        #[cfg(debug_assertions)]
124        let prev_msg_control_ptr = packet.msg_control;
125
126        // Safety: calling a libc function is inherently unsafe as rust cannot
127        // make any invariant guarantees. This has to be reviewed by humans instead
128        // so the [docs](https://linux.die.net/man/2/recmsg) are inlined here:
129
130        // > The argument sockfd is the file descriptor of the receiving socket.
131        let sockfd = socket.as_raw_fd();
132
133        // > The recvmsg() call uses a msghdr structure to minimize the number of
134        // > directly supplied arguments.
135        //
136        // > Here msg_name and msg_namelen specify the source address if the
137        // > socket is unconnected.
138        //
139        // > The fields msg_iov and msg_iovlen describe scatter-gather locations
140        //
141        // > When recvmsg() is called, msg_controllen should contain the length
142        // > of the available buffer in msg_control; upon return from a successful
143        // > call it will contain the length of the control message sequence.
144        let msg = packet;
145
146        // > The flags argument to a recv() call is formed by ORing one or more flags
147        //
148        // We set MSG_DONTWAIT if it's nonblocking or there is more than one call
149
150        // > recvmsg() calls are used to receive messages from a socket
151        //
152        // > All three routines return the length of the message on successful completion.
153        // > If a message is too long to fit in the supplied buffer, excess bytes may be
154        // > discarded depending on the type of socket the message is received from.
155        //
156        // > These calls return the number of bytes received, or -1 if an error occurred.
157        let result = libc!(recvmsg(sockfd, msg, flags));
158
159        #[cfg(debug_assertions)]
160        {
161            assert_eq!(
162                prev_msg_control_ptr, msg.msg_control,
163                "msg_control pointer was modified by the OS"
164            );
165        }
166
167        stats.recv().on_operation_result(&result, |_len| 1);
168
169        let cf = match result {
170            Ok(payload_len) => {
171                // update the message based on the return size of the syscall
172                unsafe {
173                    msg.set_payload_len(payload_len.min(u16::MAX as _).max(0) as _);
174                }
175                events.on_complete(1)
176            }
177            Err(err) => events.on_error(err),
178        };
179
180        if cf.is_break() {
181            return;
182        }
183
184        // don't block the follow-up calls
185        flags = libc::MSG_DONTWAIT;
186    }
187}