uni_socket/unix/
splice.rs

1//! Linux `splice(2)` support
2
3use std::num::NonZeroUsize;
4use std::os::fd::AsFd;
5use std::task::{ready, Poll};
6use std::{io, mem, task};
7
8use rustix::event::PollFlags;
9use rustix::fd::OwnedFd;
10use rustix::pipe::{
11    fcntl_getpipe_size, fcntl_setpipe_size, pipe_with, splice, PipeFlags, SpliceFlags,
12};
13
14use super::{OwnedReadHalf, OwnedWriteHalf, UniStream};
15
16#[derive(Debug)]
17/// `splice(2)` operation context.
18pub struct Context {
19    /// Pipe used to splice data.
20    pipe: Pipe,
21
22    /// The `off_in` when splicing from `R` to the pipe, or the `off_out` when
23    /// splicing from the pipe to `W`.
24    offset: Offset,
25
26    /// Target bytes to splice from `R` to `W`.
27    ///
28    /// Default is `isize::MAX`, which means read as much as possible. This is
29    /// also the maximum length that can be spliced in a single `splice(2)`
30    /// call.
31    target: usize,
32
33    /// Total bytes drained from `R`.
34    drained: usize,
35
36    /// Total bytes pumped to `W`.
37    pumped: usize,
38
39    /// The current state of the splice operation.
40    state: State,
41}
42
/// Phase of the copy state machine driven by [`Context::poll_copy`].
#[derive(Debug)]
enum State {
    /// Moving data from the source into the pipe.
    Draining,

    /// Moving buffered data from the pipe into the destination.
    Pumping,

    /// The copy has finished; further polls resolve to `Ok(())`.
    Done,

    /// A previous poll failed; further polls resolve to an error.
    Error,
}
50
51impl Context {
52    /// Create a new [`Context`] with given pipe.
53    pub fn new(pipe: Pipe) -> Self {
54        Self {
55            pipe,
56            offset: Offset::None,
57            target: isize::MAX as usize,
58            drained: 0,
59            pumped: 0,
60            state: State::Draining,
61        }
62    }
63
64    /// Polls copying data from `src` to `dst` using `splice(2)`.
65    ///
66    /// Notes that on multiple calls to [`poll_copy`](Self::poll_copy), only the
67    /// [`Waker`](task::Waker) from the [`task::Context`] passed to the most
68    /// recent call is scheduled to receive a wakeup, which is probably not what
69    /// you want.
70    pub fn poll_copy(
71        &mut self,
72        cx: &mut task::Context<'_>,
73        src: &mut OwnedReadHalf,
74        dst: &mut OwnedWriteHalf,
75    ) -> Poll<io::Result<()>> {
76        macro_rules! ret {
77            ($($tt:tt)*) => {
78                match $($tt)* {
79                    Ok(v) => v,
80                    Err(e) => {
81                        self.pipe.set_drain_done();
82                        self.pipe.set_pump_done();
83                        self.state = State::Error;
84                        return Poll::Ready(Err(e));
85                    }
86                }
87            };
88        }
89
90        loop {
91            match self.state {
92                State::Draining => {
93                    ret!(ready!(self.poll_drain(cx, src.as_ref())));
94
95                    self.state = State::Pumping;
96                }
97                State::Pumping => match ret!(ready!(self.poll_pump(cx, dst.as_ref()))) {
98                    Ret::Continue => self.state = State::Draining,
99                    Ret::Done => self.state = State::Done,
100                },
101                State::Done => {
102                    break Poll::Ready(Ok(()));
103                }
104                State::Error => {
105                    break Poll::Ready(Err(io::Error::new(
106                        io::ErrorKind::Other,
107                        "Polls after error",
108                    )));
109                }
110            }
111        }
112    }
113
114    #[inline]
115    /// Returns the total bytes transferred so far.
116    pub const fn transferred(&self) -> usize {
117        self.drained
118    }
119
120    /// `poll_drain` moves data from a socket (or file) to a pipe.
121    ///
122    /// # Invariants
123    ///
124    /// 1. The pipe must be empty, i.e., all data previously drained must have
125    ///    been pumped to the destination.
126    ///
127    /// # Behaviours
128    ///
129    /// - This will close the pipe write side when there's no more data to read
130    ///   (i.e., EOF).
131    /// - This will not keep draining. Instead, it drains once and returns.
132    fn poll_drain(
133        &mut self,
134        cx: &mut task::Context<'_>,
135        socket_r: &UniStream,
136    ) -> Poll<io::Result<Ret>> {
137        let Some(pipe_w) = self.pipe.w.as_fd() else {
138            return Poll::Ready(Ok(Ret::Done));
139        };
140
141        let target = {
142            let Some(target) = self.target.checked_sub(self.drained) else {
143                // TODO: panic or error?
144                return Poll::Ready(Err(io::Error::new(
145                    io::ErrorKind::Other,
146                    "unexpected: read more than target?",
147                )));
148            };
149
150            let Some(target) = NonZeroUsize::new(target) else {
151                self.pipe.set_drain_done();
152
153                return Poll::Ready(Ok(Ret::Done));
154            };
155
156            // Okay, we are sure that the pipe is empty now.
157            target.min(self.pipe.size)
158        };
159
160        loop {
161            let mut guard = ready!(socket_r.as_inner().poll_read_ready(cx))?;
162
163            match splice(
164                socket_r,
165                self.offset.off_in(),
166                pipe_w,
167                None,
168                target.get(),
169                SpliceFlags::NONBLOCK,
170            )
171            .map(NonZeroUsize::new)
172            .map_err(io::Error::from)
173            {
174                Ok(Some(drained)) => {
175                    self.drained += drained.get();
176
177                    return Poll::Ready(Ok(Ret::Continue));
178                }
179                Ok(None) => {
180                    self.pipe.set_drain_done();
181
182                    return Poll::Ready(Ok(Ret::Done));
183                }
184                Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => {}
185                Err(e) if matches!(e.kind(), io::ErrorKind::WouldBlock) => {
186                    if !test_readiness(socket_r, PollFlags::IN)? {
187                        guard.clear_ready();
188
189                        continue;
190                    }
191
192                    // Actually should not reach here, we have ensured that the pipe is empty...
193                    return Poll::Ready(Ok(Ret::Continue));
194                }
195                Err(e) => return Poll::Ready(Err(e)),
196            }
197        }
198    }
199
200    /// `poll_pump` moves data from a pipe to a socket (or file).
201    ///
202    /// # Behaviours
203    ///
204    /// - This will close the pipe read side when there's no more data to write
205    ///   (i.e., the pipe write side is closed and the draining process is
206    ///   done).
207    /// - This keeps pumping until all drained data has been written to the
208    ///   destination.
209    fn poll_pump(
210        &mut self,
211        cx: &mut task::Context<'_>,
212        socket_w: &UniStream,
213    ) -> Poll<io::Result<Ret>> {
214        let Some(pipe_r) = self.pipe.r.as_fd() else {
215            return Poll::Ready(Ok(Ret::Done));
216        };
217
218        'et_loop: loop {
219            let mut guard = ready!(socket_w.as_inner().poll_write_ready(cx))?;
220
221            loop {
222                let Some(remaining) = self.drained.checked_sub(self.pumped) else {
223                    return Poll::Ready(Err(io::Error::new(
224                        io::ErrorKind::Other,
225                        "unexpected: written more than read?",
226                    )));
227                };
228
229                let Some(remaining) = NonZeroUsize::new(remaining) else {
230                    if self.pipe.is_drain_done() {
231                        self.pipe.set_pump_done();
232                        return Poll::Ready(Ok(Ret::Done));
233                    } else {
234                        return Poll::Ready(Ok(Ret::Continue));
235                    }
236                };
237
238                match splice(
239                    pipe_r,
240                    None,
241                    socket_w,
242                    self.offset.off_out(),
243                    remaining.get(),
244                    SpliceFlags::NONBLOCK,
245                )
246                .map(NonZeroUsize::new)
247                .map_err(io::Error::from)
248                {
249                    Ok(Some(written)) => self.pumped += written.get(),
250                    Err(e) if matches!(e.kind(), io::ErrorKind::WouldBlock) => {
251                        if !test_readiness(socket_w, PollFlags::OUT)? {
252                            guard.clear_ready();
253
254                            continue 'et_loop;
255                        }
256                    }
257                    Ok(None) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
258                    Err(e) if matches!(e.kind(), io::ErrorKind::Interrupted) => {}
259                    Err(e) => return Poll::Ready(Err(e)),
260                }
261            }
262        }
263    }
264}
265
266#[inline]
267/// When we call `splice(2)` in blocking mode, although the socket itself has
268/// been set to non blocking mode, `splice(2)` operation itself may still block,
269/// which is unacceptable in asynchronous context (although some tests can still
270/// pass). For optimal compatibility, we set non blocking mode explicitly
271/// (`O_NONBLOCK` to the pipe, while `SPLICE_F_NONBLOCK` to `splice(2)`) [like
272/// Go does].
273///
274/// However, since `tokio` will probably rely on *Edge Trigger* mode, once the
275/// readiness flag of a socket on the `tokio` side is incorrectly cleared then
276/// loses sync, the `tokio` I/O task may sleep forever. Although by increasing
277/// the pipe size, limiting the number of bytes read from the src socket to
278/// the pipe in a single turn, and not [`drain`ing](Context::poll_drain) before
279/// fully [`pumping`](Context::poll_pump) bytes we has read, we can alleviate
280/// this problem, we adopt this workaround to ensure correctness: when
281/// `splice(2)` returns `EAGAIN`, we poll the socket's readiness manually to
282/// confirm whether the socket is indeed not ready. Poll always uses
283/// level-triggered mode and it does not require any registration at all.
284///
285/// Some of benchmarks did show that this workaround has little impact on
286/// performance.
287///
288/// Here's some useful references:
289///
290/// - [MengJiangProject/redproxy-rs](https://github.com/MengJiangProject/redproxy-rs/blob/4363b868bce8449441fb6364a679948d73270465/src/common/splice.rs)
291/// - [Short-read optimization is wrong for O_DIRECT pipes](https://github.com/tokio-rs/tokio/issues/7051)
292/// - [Reduce Mio's portable API to only support edge triggered notifications](https://github.com/tokio-rs/mio/issues/928)
293/// - [Linux I/O 栈与零拷贝技术全揭秘 (Chinese)](https://strikefreedom.top/archives/linux-io-stack-and-zero-copy#splice)
294/// - [Linux epoll 之 LT & ET 模式精粹 (Chinese)](https://strikefreedom.top/archives/linux-epoll-with-level-triggering-and-edge-triggering)
295///
296/// [like Go does]: https://github.com/golang/go/blob/master/src/internal/poll/splice_linux.go
297fn test_readiness(socket: &impl AsFd, flag: PollFlags) -> io::Result<bool> {
298    use rustix::event::{poll, PollFd, Timespec};
299
300    let pollfds = &mut [PollFd::new(socket, flag)];
301
302    // Set a timeout of 0, returning immediately.
303    poll(pollfds, Some(&Timespec::default()))?;
304
305    Ok(match pollfds[0].revents() {
306        PollFlags::ERR | PollFlags::HUP | PollFlags::IN => true,
307        PollFlags::NVAL => {
308            return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid fd"));
309        }
310        _ => false,
311    })
312}
313
/// File offset handed to `splice(2)`, if any.
#[derive(Debug)]
enum Offset {
    /// No offset: `splice(2)` receives `NULL` for both `off_in` and `off_out`.
    None,

    /// Offset applied to the input side (`off_in`).
    // Not constructed yet; kept for file-to-socket splicing.
    #[allow(unused)]
    In(Option<u64>),

    /// Offset applied to the output side (`off_out`).
    // Not constructed yet; kept for socket-to-file splicing.
    #[allow(unused)]
    Out(Option<u64>),
}
326
327impl Offset {
328    #[inline]
329    fn off_in(&mut self) -> Option<&mut u64> {
330        match self {
331            Offset::In(off) => off.as_mut(),
332            _ => None,
333        }
334    }
335
336    #[inline]
337    fn off_out(&mut self) -> Option<&mut u64> {
338        match self {
339            Offset::Out(off) => off.as_mut(),
340            _ => None,
341        }
342    }
343
344    #[allow(unused)]
345    fn calc_size_to_splice(
346        f_len: u64,
347        f_offset_start: Option<u64>,
348        f_offset_end: Option<u64>,
349    ) -> io::Result<u64> {
350        match (f_offset_start, f_offset_end) {
351            (Some(start), Some(end)) => {
352                if start > end || end > f_len {
353                    return Err(io::Error::new(
354                        io::ErrorKind::InvalidInput,
355                        "invalid offset range",
356                    ));
357                }
358                Ok(end - start)
359            }
360            (Some(start), None) => {
361                if start > f_len {
362                    return Err(io::Error::new(
363                        io::ErrorKind::InvalidInput,
364                        "invalid offset start",
365                    ));
366                }
367                Ok(f_len - start)
368            }
369            (None, Some(end)) => {
370                if end > f_len {
371                    return Err(io::Error::new(
372                        io::ErrorKind::InvalidInput,
373                        "invalid offset end",
374                    ));
375                }
376                Ok(end)
377            }
378            (None, None) => Ok(f_len),
379        }
380    }
381}
382
/// Largest pipe capacity requested from the kernel, and therefore the most
/// data a single `splice(2)` call is asked to move.
///
/// 1 MiB is chosen because data is spliced through a pipe, and 1 MiB is the
/// default value of `/proc/sys/fs/pipe-max-size`, the upper bound on pipe
/// buffer sizes.
///
/// Applications running as an unprivileged user may additionally be limited
/// in how many pipe pages they can use. See [`pipe(7)`] for details.
///
/// [`pipe(7)`]: https://man7.org/linux/man-pages/man7/pipe.7.html
const MAXIMUM_PIPE_SIZE: usize = 1024 * 1024;
395
/// Initial pipe capacity assumed for a new pipe: 16 pages, i.e. 64 KiB with a
/// 4 KiB page size.
const DEFAULT_PIPE_SIZE: usize = 16 * 4096;
398
399#[derive(Debug)]
400/// `Pipe`.
401pub struct Pipe {
402    /// File descriptor for reading from the pipe
403    r: Fd,
404
405    /// File descriptor for writing to the pipe
406    w: Fd,
407
408    /// Pipe size in bytes.
409    size: NonZeroUsize,
410}
411
/// Lifecycle state of one end of a [`Pipe`].
#[derive(Debug)]
enum Fd {
    /// Open and in active use for reading or writing.
    Running(OwnedFd),

    /// No longer used by the current copy, but kept open so it can be
    /// recycled later.
    // The wrapped descriptor is only held, never read, hence the allow.
    #[allow(unused)]
    Reserved(OwnedFd),

    /// No descriptor is held. Also used as the temporary placeholder while an
    /// owned descriptor is moved between variants.
    Closed,
}
424
425impl Fd {
426    #[inline]
427    /// Convert the file descriptor to a pending state.
428    ///
429    /// This is used to indicate that the file descriptor is reserved for future
430    /// use.
431    fn set_reserved(&mut self) {
432        if let Fd::Running(owned_fd) = mem::replace(self, Fd::Closed) {
433            *self = Fd::Reserved(owned_fd);
434        }
435    }
436
437    #[inline]
438    const fn as_fd(&self) -> Option<&OwnedFd> {
439        match self {
440            Fd::Running(fd) => Some(fd),
441            _ => None,
442        }
443    }
444}
445
446impl Pipe {
447    /// Create a [`Pipe`], with flags `O_CLOEXEC`.
448    ///
449    /// The default pipe size is set to `MAXIMUM_PIPE_SIZE` bytes.
450    ///
451    /// # Errors
452    ///
453    /// See [`pipe(2)`] and [`fcntl(2)`].
454    ///
455    /// [`pipe(2)`]: https://man7.org/linux/man-pages/man2/pipe.2.html
456    /// [`fcntl(2)`]: https://man7.org/linux/man-pages/man2/fcntl.2.html
457    pub fn new() -> io::Result<Self> {
458        pipe_with(PipeFlags::CLOEXEC | PipeFlags::NONBLOCK)
459            .map_err(|e| io::Error::from_raw_os_error(e.raw_os_error()))
460            .map(|(r, w)| Self {
461                r: Fd::Running(r),
462                w: Fd::Running(w),
463                size: NonZeroUsize::new(DEFAULT_PIPE_SIZE).unwrap(),
464            })
465            .and_then(|mut this| {
466                // Splice will loop writing MAXIMUM_PIPE_SIZE bytes from the source to the pipe,
467                // and then write those bytes from the pipe to the destination.
468                // Set the pipe buffer size to MAXIMUM_PIPE_SIZE to optimize that.
469                // Ignore errors here, as a smaller buffer size will work,
470                // although it will require more system calls.
471
472                this.update_pipe_size(NonZeroUsize::new(MAXIMUM_PIPE_SIZE).unwrap())?;
473
474                Ok(this)
475            })
476    }
477
478    /// Sets and updates the pipe size.
479    ///
480    /// ## Errors
481    ///
482    /// See [`fcntl(2)`].
483    ///
484    /// [`fcntl(2)`]: https://man7.org/linux/man-pages/man2/fcntl.2.html.
485    pub fn update_pipe_size(&mut self, size: NonZeroUsize) -> io::Result<usize> {
486        let Some(r) = self.r.as_fd() else {
487            return Err(io::Error::new(
488                io::ErrorKind::Other,
489                "pipe is not available",
490            ));
491        };
492
493        match fcntl_setpipe_size(&r, size.get()).map(NonZeroUsize::new) {
494            Ok(Some(size)) => {
495                self.size = size;
496
497                Ok(self.size.get())
498            }
499            Ok(None) => {
500                self.size = fcntl_getpipe_size(&r)
501                    .ok()
502                    .and_then(NonZeroUsize::new)
503                    .ok_or_else(|| {
504                        io::Error::new(
505                            io::ErrorKind::Other,
506                            "Failed to get pipe size while fcntl returned zero",
507                        )
508                    })?;
509
510                Err(io::Error::new(io::ErrorKind::Other, "fcntl returned zero"))
511            }
512            Err(e) => {
513                self.size = fcntl_getpipe_size(&r)
514                    .ok()
515                    .and_then(NonZeroUsize::new)
516                    .ok_or_else(|| {
517                        io::Error::new(
518                            io::ErrorKind::Other,
519                            format!("Failed to get pipe size while fcntl returned error: {e}"),
520                        )
521                    })?;
522
523                Err(io::Error::from_raw_os_error(e.raw_os_error()))
524            }
525        }
526    }
527
528    #[inline]
529    const fn is_drain_done(&self) -> bool {
530        matches!(self.w, Fd::Reserved(_) | Fd::Closed)
531    }
532
533    #[inline]
534    /// Close the pipe write side file descriptor.
535    fn set_drain_done(&mut self) {
536        self.w.set_reserved();
537    }
538
539    #[inline]
540    #[allow(unused)]
541    const fn is_pump_done(&self) -> bool {
542        matches!(self.r, Fd::Reserved(_) | Fd::Closed)
543    }
544
545    #[inline]
546    /// Close the pipe read side file descriptor.
547    fn set_pump_done(&mut self) {
548        self.r.set_reserved();
549    }
550
551    #[inline]
552    /// Returns the capacity of the pipe, in bytes.
553    pub const fn size(&self) -> NonZeroUsize {
554        self.size
555    }
556}
557
/// Outcome of a single drain or pump step of the copy state machine.
enum Ret {
    /// More work remains in the other direction.
    ///
    /// After draining, the pipe holds data that should now be pumped to the
    /// destination. After pumping, everything drained so far has been written
    /// and more data can be drained.
    Continue,

    /// This direction is finished.
    ///
    /// For draining, the source has no more data (i.e. EOF), although some may
    /// still sit in the pipe waiting to be pumped. For pumping, every drained
    /// byte has been written to the destination.
    Done,
}