tokio_vsock/
stream.rs

1/*
2 * Tokio Reference TCP Implementation
3 * Copyright (c) 2019 Tokio Contributors
4 *
5 * Permission is hereby granted, free of charge, to any
6 * person obtaining a copy of this software and associated
7 * documentation files (the "Software"), to deal in the
8 * Software without restriction, including without
9 * limitation the rights to use, copy, modify, merge,
10 * publish, distribute, sublicense, and/or sell copies of
11 * the Software, and to permit persons to whom the Software
12 * is furnished to do so, subject to the following
13 * conditions:
14 *
15 * The above copyright notice and this permission notice
16 * shall be included in all copies or substantial portions
17 * of the Software.
18 *
19 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
20 * ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
21 * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
22 * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
23 * SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
24 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
26 * IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
27 * DEALINGS IN THE SOFTWARE.
28 */
29
30/*
31 * Copyright 2019 fsyncd, Berlin, Germany.
32 *
33 * Licensed under the Apache License, Version 2.0 (the "License");
34 * you may not use this file except in compliance with the License.
35 * You may obtain a copy of the License at
36 *
37 *     http://www.apache.org/licenses/LICENSE-2.0
38 *
39 * Unless required by applicable law or agreed to in writing, software
40 * distributed under the License is distributed on an "AS IS" BASIS,
41 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42 * See the License for the specific language governing permissions and
43 * limitations under the License.
44 */
45
46use std::io::{Error, Read, Result, Write};
47use std::net::Shutdown;
48use std::os::fd::{AsFd, BorrowedFd};
49use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
50
51use crate::split::{split_owned, OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf};
52use crate::VsockAddr;
53use futures::ready;
54use libc::*;
55use std::mem::{self, size_of};
56use std::pin::Pin;
57use std::task::{Context, Poll};
58use tokio::io::unix::AsyncFd;
59use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
60
/// An I/O object representing a Virtio socket connected to a remote endpoint.
///
/// The stream wraps a `vsock::VsockStream` in a Tokio [`AsyncFd`]; the
/// asynchronous read/write paths wait for readiness through that wrapper and
/// then perform the I/O on the underlying socket.
#[derive(Debug)]
pub struct VsockStream {
    // The underlying vsock socket wrapped in `AsyncFd`; readiness polling and
    // all socket operations in this file go through it.
    inner: AsyncFd<vsock::VsockStream>,
}
66
67impl VsockStream {
68    pub fn new(connected: vsock::VsockStream) -> Result<Self> {
69        connected.set_nonblocking(true)?;
70        Ok(Self {
71            inner: AsyncFd::new(connected)?,
72        })
73    }
74
75    /// Open a connection to a remote host.
76    pub async fn connect(addr: VsockAddr) -> Result<Self> {
77        let socket = unsafe { socket(AF_VSOCK, SOCK_STREAM, 0) };
78        if socket < 0 {
79            return Err(Error::last_os_error());
80        }
81
82        if unsafe { fcntl(socket, F_SETFL, O_NONBLOCK | O_CLOEXEC) } < 0 {
83            let _ = unsafe { close(socket) };
84            return Err(Error::last_os_error());
85        }
86
87        if unsafe {
88            connect(
89                socket,
90                &addr as *const _ as *const sockaddr,
91                size_of::<sockaddr_vm>() as socklen_t,
92            )
93        } < 0
94        {
95            let err = Error::last_os_error();
96            if let Some(os_err) = err.raw_os_error() {
97                // Connect hasn't finished, that's fine.
98                if os_err != EINPROGRESS {
99                    // Close the socket if we hit an error, ignoring the error
100                    // from closing since we can't pass back two errors.
101                    let _ = unsafe { close(socket) };
102                    return Err(err);
103                }
104            }
105        }
106
107        loop {
108            let stream = unsafe { vsock::VsockStream::from_raw_fd(socket) };
109            let stream = Self::new(stream)?;
110            let mut guard = stream.inner.writable().await?;
111
112            // Checks if the connection failed or not
113            let conn_check = guard.try_io(|fd| {
114                let mut sock_err: c_int = 0;
115                let mut sock_err_len: socklen_t = size_of::<c_int>() as socklen_t;
116                let err = unsafe {
117                    getsockopt(
118                        fd.as_raw_fd(),
119                        SOL_SOCKET,
120                        SO_ERROR,
121                        &mut sock_err as *mut _ as *mut c_void,
122                        &mut sock_err_len as *mut socklen_t,
123                    )
124                };
125                if err < 0 {
126                    return Err(Error::last_os_error());
127                }
128                if sock_err == 0 {
129                    Ok(())
130                } else {
131                    Err(Error::from_raw_os_error(sock_err))
132                }
133            });
134
135            match conn_check {
136                Ok(Ok(_)) => return Ok(stream),
137                Ok(Err(err)) => return Err(err),
138                Err(_would_block) => continue,
139            }
140        }
141    }
142
143    /// The local address that this socket is bound to.
144    pub fn local_addr(&self) -> Result<VsockAddr> {
145        self.inner.get_ref().local_addr()
146    }
147
148    /// The remote address that this socket is connected to.
149    pub fn peer_addr(&self) -> Result<VsockAddr> {
150        self.inner.get_ref().peer_addr()
151    }
152
153    /// Shuts down the read, write, or both halves of this connection.
154    pub fn shutdown(&self, how: Shutdown) -> Result<()> {
155        self.inner.get_ref().shutdown(how)
156    }
157
158    /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
159    /// `AsyncRead` and `AsyncWrite` handles.
160    pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
161        crate::split::split(self)
162    }
163
164    /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
165    /// `AsyncRead` and `AsyncWrite` handles.
166    ///
167    /// To restore this read/write object from its `OwnedReadHalf` and
168    /// `OwnedWriteHalf` use [`unsplit`](OwnedReadHalf::unsplit()).
169    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
170        split_owned(self)
171    }
172
173    pub(crate) fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
174        loop {
175            let mut guard = ready!(self.inner.poll_write_ready(cx))?;
176
177            match guard.try_io(|inner| inner.get_ref().write(buf)) {
178                Ok(Ok(n)) => return Ok(n).into(),
179                Ok(Err(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => continue,
180                Ok(Err(e)) => return Err(e).into(),
181                Err(_would_block) => continue,
182            }
183        }
184    }
185
186    pub(crate) fn poll_read_priv(
187        &self,
188        cx: &mut Context<'_>,
189        buf: &mut ReadBuf<'_>,
190    ) -> Poll<Result<()>> {
191        let b;
192        unsafe {
193            b = &mut *(buf.unfilled_mut() as *mut [mem::MaybeUninit<u8>] as *mut [u8]);
194        };
195
196        loop {
197            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
198
199            match guard.try_io(|inner| inner.get_ref().read(b)) {
200                Ok(Ok(n)) => {
201                    unsafe {
202                        buf.assume_init(n);
203                    }
204                    buf.advance(n);
205                    return Ok(()).into();
206                }
207                Ok(Err(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => continue,
208                Ok(Err(e)) => return Err(e).into(),
209                Err(_would_block) => {
210                    continue;
211                }
212            }
213        }
214    }
215}
216
217impl AsFd for VsockStream {
218    fn as_fd(&self) -> BorrowedFd<'_> {
219        self.inner.get_ref().as_fd()
220    }
221}
222
223impl AsRawFd for VsockStream {
224    fn as_raw_fd(&self) -> RawFd {
225        self.inner.get_ref().as_raw_fd()
226    }
227}
228
229impl IntoRawFd for VsockStream {
230    fn into_raw_fd(self) -> RawFd {
231        let fd = self.inner.get_ref().as_raw_fd();
232        mem::forget(self);
233        fd
234    }
235}
236
237impl Write for VsockStream {
238    fn write(&mut self, buf: &[u8]) -> Result<usize> {
239        self.inner.get_ref().write(buf)
240    }
241
242    fn flush(&mut self) -> Result<()> {
243        Ok(())
244    }
245}
246
247impl Read for VsockStream {
248    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
249        self.inner.get_ref().read(buf)
250    }
251}
252
253impl AsyncWrite for VsockStream {
254    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
255        self.poll_write_priv(cx, buf)
256    }
257
258    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
259        Poll::Ready(Ok(()))
260    }
261
262    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
263        self.shutdown(std::net::Shutdown::Write)?;
264        Poll::Ready(Ok(()))
265    }
266}
267
268impl AsyncRead for VsockStream {
269    fn poll_read(
270        self: Pin<&mut Self>,
271        cx: &mut Context<'_>,
272        buf: &mut ReadBuf<'_>,
273    ) -> Poll<Result<()>> {
274        self.poll_read_priv(cx, buf)
275    }
276}