tokio_splice2/
io.rs

1//! `splice(2)` I/O implementation.
2
3use std::future::poll_fn;
4use std::marker::PhantomData;
5#[cfg(not(feature = "feat-rate-limit"))]
6use std::marker::PhantomPinned;
7use std::os::fd::AsFd;
8use std::pin::{pin, Pin};
9use std::task::{ready, Context, Poll};
10use std::{io, ops};
11
12use crossbeam_utils::CachePadded;
13use tokio::fs::File;
14use tokio::io::{AsyncRead, AsyncWrite, Interest};
15use tokio::net::{TcpStream, UnixStream};
16#[cfg(feature = "feat-rate-limit")]
17use tokio::time::Sleep;
18
19use crate::context::SpliceIoCtx;
20use crate::rate::RATE_LIMITER_DISABLED;
21#[cfg(feature = "feat-rate-limit")]
22use crate::rate::{RateLimit, RateLimitResult, RateLimiter, RATE_LIMITER_ENABLED};
23use crate::traffic::TrafficResult;
24use crate::utils::Drained;
25
26#[pin_project::pin_project]
27#[derive(Debug)]
28/// Zero-copy unidirectional I/O with `splice(2)`.
29///
30/// For bidirectional I/O version, see [`SpliceBidiIo`].
31///
32/// Notice: see the [module-level documentation](crate) for known limitations.
33pub struct SpliceIo<R, W, const RATE_LIMITER_IS_ENABLED: bool = RATE_LIMITER_DISABLED> {
34    /// Context for the splice I/O operation.
35    ///
36    /// See [`SpliceIoCtx`] for more details.
37    ctx: CachePadded<SpliceIoCtx<R, W>>,
38
39    r: PhantomData<R>,
40    w: PhantomData<W>,
41
42    #[cfg(feature = "feat-rate-limit")]
43    /// To limit the transfer speed.
44    rate_limiter: RateLimiter<RATE_LIMITER_IS_ENABLED>,
45
46    #[pin]
47    state: TransferState,
48}
49
50impl<R, W, const RATE_LIMITER_IS_ENABLED: bool> ops::Deref
51    for SpliceIo<R, W, RATE_LIMITER_IS_ENABLED>
52{
53    type Target = SpliceIoCtx<R, W>;
54
55    fn deref(&self) -> &Self::Target {
56        &self.ctx
57    }
58}
59
60impl<R, W> SpliceIo<R, W, RATE_LIMITER_DISABLED> {
61    #[must_use]
62    /// Create a new `SpliceIo` instance with ctx and pinned `R` / `W`.
63    pub fn new(ctx: SpliceIoCtx<R, W>) -> Self {
64        SpliceIo {
65            ctx: CachePadded::new(ctx),
66            r: PhantomData,
67            w: PhantomData,
68            #[cfg(feature = "feat-rate-limit")]
69            rate_limiter: RateLimiter::empty(),
70            state: TransferState::Draining,
71        }
72    }
73
74    #[cfg(feature = "feat-rate-limit")]
75    /// Apply rate limitation during the splice I/O.
76    ///
77    /// See [`RateLimit`] for more details.
78    pub fn with_rate_limit(self, limit: RateLimit) -> SpliceIo<R, W, RATE_LIMITER_ENABLED> {
79        SpliceIo {
80            ctx: self.ctx,
81            r: self.r,
82            w: self.w,
83            rate_limiter: RateLimiter::new(limit),
84            state: self.state,
85        }
86    }
87}
88
89#[derive(Debug)]
90#[pin_project::pin_project(project = TransferStateProj)]
91enum TransferState {
92    /// Draining data from `R` to pipe.
93    Draining,
94
95    #[cfg_attr(not(feature = "feat-rate-limit"), allow(dead_code))]
96    /// Rate limiter is throttling the I/O operations.
97    Throttled {
98        #[cfg(feature = "feat-rate-limit")]
99        #[pin]
100        sleep: Sleep,
101
102        #[cfg(not(feature = "feat-rate-limit"))]
103        #[doc(hidden)]
104        #[pin]
105        // Make `pin-project` happy.
106        _pinned: PhantomPinned,
107    },
108
109    /// Pumping data from pipe to `W`.
110    Pumping,
111
112    /// Flushing buffered data of `W`.
113    Flushing,
114
115    /// Transfer is finished, `W` is shutting down.
116    Terminating,
117
118    /// An error occurred during the transfer.
119    Faulted { error: Option<io::Error> },
120
121    /// Transfer is finished.
122    Finished,
123}
124
125impl<R, W, const RATE_LIMITER_IS_ENABLED: bool> SpliceIo<R, W, RATE_LIMITER_IS_ENABLED>
126where
127    R: AsyncReadFd,
128    W: AsyncWriteFd,
129{
130    /// Performs zero-copy data transfer from reader `R` to writer `W` using the
131    /// splice syscall.
132    ///
133    /// This is a convenient `async fn` version of
134    /// [`SpliceIo::poll_execute`].
135    pub async fn execute(self, r: &mut R, w: &mut W) -> TrafficResult
136    where
137        R: Unpin,
138        W: Unpin,
139    {
140        let mut this = pin!(self);
141        let mut r = Pin::new(r);
142        let mut w = Pin::new(w);
143
144        let error = poll_fn(|cx| this.as_mut().poll_execute(cx, r.as_mut(), w.as_mut()))
145            .await
146            .err();
147
148        this.ctx.traffic_client_tx(error)
149    }
150
151    #[cfg_attr(
152        any(
153            feature = "feat-tracing-trace",
154            all(debug_assertions, feature = "feat-tracing")
155        ),
156        tracing::instrument(level = "TRACE", skip(self, cx, r, w), ret)
157    )]
158    #[allow(clippy::too_many_lines)]
159    /// Performs zero-copy data transfer from reader `R` to writer `W` using the
160    /// splice syscall.
161    ///
162    /// This is the `poll`-based asynchronous version.
163    ///
164    /// # Notes
165    ///
166    /// This is an advanced API that should only be used if you fully understand
167    /// its behavior. When using this API:
168    ///
169    /// - The [`SpliceIo`] instance MUST NOT be reused after completion.
170    /// - The caller MAY manually extracts [`TrafficResult`] from the context.
171    pub fn poll_execute(
172        mut self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        mut r: Pin<&mut R>,
175        mut w: Pin<&mut W>,
176    ) -> Poll<io::Result<()>> {
177        macro_rules! ready_or_cleanup {
178            ($e:expr, $state:expr) => {
179                match $e {
180                    Poll::Ready(Ok(t)) => t,
181                    Poll::Ready(Err(e)) => {
182                        $state.set(TransferState::Faulted { error: Some(e) });
183                        continue;
184                    }
185                    Poll::Pending => {
186                        break Poll::Pending;
187                    }
188                }
189            };
190        }
191
192        loop {
193            crate::enter_tracing_span!(
194                "loop",
195                ctx = ?self.ctx,
196                state = ?self.state,
197            );
198
199            let mut this = self.as_mut().project();
200
201            match this.state.as_mut().project() {
202                TransferStateProj::Draining => {
203                    #[cfg(feature = "feat-rate-limit")]
204                    let ideal_len = this.rate_limiter.ideal_len(this.ctx.pipe_size());
205
206                    #[cfg(not(feature = "feat-rate-limit"))]
207                    let ideal_len = None;
208
209                    match ready_or_cleanup!(
210                        this.ctx.poll_splice_drain(cx, r.as_mut(), ideal_len),
211                        this.state.as_mut()
212                    ) {
213                        Drained::Some(_drained) => {
214                            #[cfg(feature = "feat-rate-limit")]
215                            {
216                                #[allow(clippy::used_underscore_binding)]
217                                match this.rate_limiter.check(_drained) {
218                                    RateLimitResult::Accepted => {}
219                                    RateLimitResult::Throttled { now, dur } => {
220                                        this.state.as_mut().set(TransferState::Throttled {
221                                            sleep: tokio::time::sleep_until(now + dur),
222                                        });
223                                        continue;
224                                    }
225                                }
226                            }
227                        }
228                        Drained::Done => {}
229                    }
230
231                    this.state.set(TransferState::Pumping);
232                }
233                #[cfg(feature = "feat-rate-limit")]
234                TransferStateProj::Throttled { sleep } => {
235                    use std::future::Future;
236
237                    ready!(sleep.poll(cx));
238
239                    // After throttled, we shall continue to pump data from pipe to `W`.
240                    this.state.set(TransferState::Pumping);
241                }
242                #[cfg(not(feature = "feat-rate-limit"))]
243                TransferStateProj::Throttled { _pinned } => {
244                    // Actually, this branch should never be reached.
245                    this.state.set(TransferState::Pumping);
246                }
247                TransferStateProj::Pumping => {
248                    ready_or_cleanup!(
249                        this.ctx.poll_splice_pump(cx, w.as_mut()),
250                        this.state.as_mut()
251                    );
252
253                    if this.ctx.finished() {
254                        // All done, flush and shutdown `W`.
255                        this.state.set(TransferState::Terminating);
256                    } else {
257                        // Flush `W` after pumping data.
258                        this.state.set(TransferState::Flushing);
259                    }
260                }
261                TransferStateProj::Flushing => {
262                    ready_or_cleanup!(w.as_mut().poll_flush(cx), this.state.as_mut());
263
264                    this.state.set(TransferState::Draining);
265                }
266                TransferStateProj::Terminating => {
267                    ready_or_cleanup!(w.as_mut().poll_shutdown(cx), this.state.as_mut());
268
269                    this.state.set(TransferState::Finished);
270                }
271                TransferStateProj::Faulted { error } => {
272                    if error.is_some() {
273                        // Best effort to shutdown the writer.
274                        ready!(w.as_mut().poll_shutdown(cx))?;
275                    } else {
276                        #[cfg(feature = "feat-nightly")]
277                        std::hint::cold_path();
278                    }
279
280                    let Some(error) = error.take() else {
281                        #[cfg(feature = "feat-nightly")]
282                        std::hint::cold_path();
283
284                        break Poll::Ready(Err(io::Error::new(
285                            io::ErrorKind::Other,
286                            "`poll_execute()` called after error returned",
287                        )));
288                    };
289
290                    break Poll::Ready(Err(error));
291                }
292                TransferStateProj::Finished => {
293                    break Poll::Ready(Ok(()));
294                }
295            }
296        }
297    }
298}
299
300#[pin_project::pin_project]
301#[derive(Debug)]
302/// Bidirectional splice I/O, combining two `SpliceIo` instances.
303pub struct SpliceBidiIo<
304    SL,
305    SR,
306    const SL_RATE_LIMITER_IS_ENABLED: bool,
307    const SR_RATE_LIMITER_IS_ENABLED: bool,
308> {
309    #[pin]
310    /// Splice I/O instance, from `SL` to `SR`.
311    pub io_sl2sr: SpliceIo<SL, SR, SL_RATE_LIMITER_IS_ENABLED>,
312
313    #[pin]
314    /// Splice I/O instance, from `SR` to `SL`.
315    pub io_sr2sl: SpliceIo<SR, SL, SR_RATE_LIMITER_IS_ENABLED>,
316}
317
318impl<SL, SR, const SL_RATE_LIMITER_IS_ENABLED: bool, const SR_RATE_LIMITER_IS_ENABLED: bool>
319    SpliceBidiIo<SL, SR, SL_RATE_LIMITER_IS_ENABLED, SR_RATE_LIMITER_IS_ENABLED>
320where
321    SL: AsyncReadFd + AsyncWriteFd + IsNotFile,
322    SR: AsyncReadFd + AsyncWriteFd + IsNotFile,
323{
324    /// Performs zero-copy data transfer between `SL` and `SR` using the
325    /// splice syscall.
326    ///
327    /// This is a convenient `async fn` version of
328    /// [`SpliceBidiIo::poll_execute`].
329    pub async fn execute(self, sl: &mut SL, sr: &mut SR) -> TrafficResult
330    where
331        SL: Unpin,
332        SR: Unpin,
333    {
334        let mut this = pin!(self);
335        let mut sl = Pin::new(sl);
336        let mut sr = Pin::new(sr);
337
338        let error = poll_fn(|cx| this.as_mut().poll_execute(cx, sl.as_mut(), sr.as_mut()))
339            .await
340            .err();
341
342        // After copy done, we can return the traffic result.
343        this.io_sl2sr
344            .ctx
345            .traffic_client_tx(error)
346            .merge(this.io_sr2sl.ctx.traffic_client_rx(None))
347    }
348
349    #[cfg_attr(
350        any(
351            feature = "feat-tracing-trace",
352            all(debug_assertions, feature = "feat-tracing")
353        ),
354        tracing::instrument(
355            level = "TRACE",
356            name = "SpliceBidiIo::poll_execute",
357            skip(self, cx, sl, sr),
358            ret
359        )
360    )]
361    /// Performs zero-copy data transfer between `SL` and `SR` using the
362    /// splice syscall.
363    ///
364    /// This is the `poll`-based asynchronous version.
365    ///
366    /// # Notes
367    ///
368    /// This is an advanced API that should only be used if you fully understand
369    /// its behavior. When using this API:
370    ///
371    /// - The [`SpliceBidiIo`] instance MUST NOT be reused after completion.
372    /// - The caller MAY manually extracts [`TrafficResult`] from the context.
373    pub fn poll_execute(
374        self: Pin<&mut Self>,
375        cx: &mut Context<'_>,
376        mut sl: Pin<&mut SL>,
377        mut sr: Pin<&mut SR>,
378    ) -> Poll<io::Result<()>> {
379        let mut this = self.project();
380
381        let io_sl2sr_ret = this
382            .io_sl2sr
383            .as_mut()
384            .poll_execute(cx, sl.as_mut(), sr.as_mut());
385        let io_sr2sl_ret = this
386            .io_sr2sl
387            .as_mut()
388            .poll_execute(cx, sr.as_mut(), sl.as_mut());
389
390        #[cfg(not(feature = "feat-brutal-shutdown"))]
391        {
392            match (io_sl2sr_ret, io_sr2sl_ret) {
393                (Poll::Pending, _) | (_, Poll::Pending) => Poll::Pending,
394                (Poll::Ready(Ok(())), Poll::Ready(Ok(()))) => Poll::Ready(Ok(())),
395                (Poll::Ready(Err(e)), _) | (_, Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
396            }
397        }
398
399        #[cfg(feature = "feat-brutal-shutdown")]
400        {
401            match (io_sl2sr_ret, io_sr2sl_ret) {
402                (Poll::Pending, Poll::Pending) => Poll::Pending,
403                (Poll::Ready(Err(e)), _) | (_, Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
404                // Once received `FIN`, close the other side immediately.
405                (Poll::Ready(Ok(())), _) | (_, Poll::Ready(Ok(()))) => Poll::Ready(Ok(())),
406            }
407        }
408    }
409}
410
411// === traits ===
412
/// Marker trait: indicates a file.
///
/// The compiler reports *conflicting implementations* if we try to implement
/// `IsFile` for every `T: ops::Deref<Target = U>` where `U: IsFile`, so a
/// wrapper type around a file has to implement this marker trait itself.
pub trait IsFile {}

// Forward the marker through mutable borrows and pinned mutable borrows.
impl<T: IsFile> IsFile for &mut T {}
impl<T: IsFile> IsFile for Pin<&mut T> {}
422
/// Marker trait: indicates not a file.
///
/// This exists because Rust offers no `!IsFile` syntax; that kind of negative
/// reasoning is at most available for a few built-in marker traits such as
/// `Send`.
pub trait IsNotFile {}

// Forward the marker through mutable borrows and pinned mutable borrows.
impl<T: IsNotFile> IsNotFile for &mut T {}
impl<T: IsNotFile> IsNotFile for Pin<&mut T> {}
431
432/// Marker trait: indicates an async-readable file descriptor.
433///
434/// This trait extends both `AsyncRead` and `AsFd`, providing the necessary
435/// methods for async reading operations with splice.
436pub trait AsyncReadFd: AsyncRead + AsFd {
437    #[doc(hidden)]
438    fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>>;
439
440    #[doc(hidden)]
441    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R>;
442}
443
444impl<T: AsyncReadFd + Unpin> AsyncReadFd for &mut T {
445    #[inline]
446    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
447        (**self).poll_read_ready(cx)
448    }
449
450    #[inline]
451    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
452        (**self).try_io_read(f)
453    }
454}
455
456/// Marker trait: indicates an async-writable file descriptor.
457///
458/// This trait extends both `AsyncWrite` and `AsFd`, providing the necessary
459/// methods for async writing operations with splice.
460pub trait AsyncWriteFd: AsyncWrite + AsFd {
461    #[doc(hidden)]
462    fn poll_write_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>>;
463
464    #[doc(hidden)]
465    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R>;
466}
467
468impl<T: AsyncWriteFd + Unpin> AsyncWriteFd for &mut T {
469    #[inline]
470    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
471        (**self).poll_write_ready(cx)
472    }
473
474    #[inline]
475    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
476        (**self).try_io_write(f)
477    }
478}
479
// Implements `AsyncReadFd`, `AsyncWriteFd` and the matching file marker for
// each listed type.
//
// - `impl_async_fd!(FILE: T, ...)`: readiness is always reported and the I/O
//   closure is invoked directly; the type is marked `IsFile`.
// - `impl_async_fd!(T, ...)`: readiness and `try_io` are delegated to the
//   type's inherent tokio methods; the type is marked `IsNotFile`.
//
// The `FILE:` arm is listed first so its literal prefix is tried before the
// generic type-list arm.
macro_rules! impl_async_fd {
    (FILE: $($file:ty),+) => {
        $(
            impl AsyncReadFd for $file {
                #[inline]
                fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Always report ready.
                    Poll::Ready(Ok(()))
                }

                #[inline]
                fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    f()
                }
            }

            impl AsyncWriteFd for $file {
                #[inline]
                fn poll_write_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Always report ready.
                    Poll::Ready(Ok(()))
                }

                #[inline]
                fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    f()
                }
            }

            impl IsFile for $file {}
        )+
    };
    ($($sock:ty),+) => {
        $(
            impl AsyncReadFd for $sock {
                #[inline]
                fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Inherent items take precedence over trait items, so this
                    // calls the type's own `poll_read_ready`.
                    <$sock>::poll_read_ready(self, cx)
                }

                #[inline]
                fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    <$sock>::try_io(self, Interest::READABLE, f)
                }
            }

            impl AsyncWriteFd for $sock {
                #[inline]
                fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Inherent items take precedence over trait items, so this
                    // calls the type's own `poll_write_ready`.
                    <$sock>::poll_write_ready(self, cx)
                }

                #[inline]
                fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    <$sock>::try_io(self, Interest::WRITABLE, f)
                }
            }

            impl IsNotFile for $sock {}
        )+
    };
}
540
541impl_async_fd!(TcpStream, UnixStream);
542impl_async_fd!(FILE: File);