1use std::io::{Error, Read, Result, Write};
47use std::net::Shutdown;
48use std::os::fd::{AsFd, BorrowedFd};
49use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
50
51use crate::split::{split_owned, OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf};
52use crate::VsockAddr;
53use futures::ready;
54use libc::*;
55use std::mem::{self, size_of};
56use std::pin::Pin;
57use std::task::{Context, Poll};
58use tokio::io::unix::AsyncFd;
59use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
60
/// An asynchronous vsock stream: a `vsock::VsockStream` registered with the tokio
/// reactor through [`AsyncFd`], implementing tokio's [`AsyncRead`] / [`AsyncWrite`].
#[derive(Debug)]
pub struct VsockStream {
    /// The wrapped socket together with its tokio reactor registration.
    inner: AsyncFd<vsock::VsockStream>,
}
66
67impl VsockStream {
68 pub fn new(connected: vsock::VsockStream) -> Result<Self> {
69 connected.set_nonblocking(true)?;
70 Ok(Self {
71 inner: AsyncFd::new(connected)?,
72 })
73 }
74
75 pub async fn connect(addr: VsockAddr) -> Result<Self> {
77 let socket = unsafe { socket(AF_VSOCK, SOCK_STREAM, 0) };
78 if socket < 0 {
79 return Err(Error::last_os_error());
80 }
81
82 if unsafe { fcntl(socket, F_SETFL, O_NONBLOCK | O_CLOEXEC) } < 0 {
83 let _ = unsafe { close(socket) };
84 return Err(Error::last_os_error());
85 }
86
87 if unsafe {
88 connect(
89 socket,
90 &addr as *const _ as *const sockaddr,
91 size_of::<sockaddr_vm>() as socklen_t,
92 )
93 } < 0
94 {
95 let err = Error::last_os_error();
96 if let Some(os_err) = err.raw_os_error() {
97 if os_err != EINPROGRESS {
99 let _ = unsafe { close(socket) };
102 return Err(err);
103 }
104 }
105 }
106
107 loop {
108 let stream = unsafe { vsock::VsockStream::from_raw_fd(socket) };
109 let stream = Self::new(stream)?;
110 let mut guard = stream.inner.writable().await?;
111
112 let conn_check = guard.try_io(|fd| {
114 let mut sock_err: c_int = 0;
115 let mut sock_err_len: socklen_t = size_of::<c_int>() as socklen_t;
116 let err = unsafe {
117 getsockopt(
118 fd.as_raw_fd(),
119 SOL_SOCKET,
120 SO_ERROR,
121 &mut sock_err as *mut _ as *mut c_void,
122 &mut sock_err_len as *mut socklen_t,
123 )
124 };
125 if err < 0 {
126 return Err(Error::last_os_error());
127 }
128 if sock_err == 0 {
129 Ok(())
130 } else {
131 Err(Error::from_raw_os_error(sock_err))
132 }
133 });
134
135 match conn_check {
136 Ok(Ok(_)) => return Ok(stream),
137 Ok(Err(err)) => return Err(err),
138 Err(_would_block) => continue,
139 }
140 }
141 }
142
143 pub fn local_addr(&self) -> Result<VsockAddr> {
145 self.inner.get_ref().local_addr()
146 }
147
148 pub fn peer_addr(&self) -> Result<VsockAddr> {
150 self.inner.get_ref().peer_addr()
151 }
152
153 pub fn shutdown(&self, how: Shutdown) -> Result<()> {
155 self.inner.get_ref().shutdown(how)
156 }
157
158 pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
161 crate::split::split(self)
162 }
163
164 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
170 split_owned(self)
171 }
172
173 pub(crate) fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
174 loop {
175 let mut guard = ready!(self.inner.poll_write_ready(cx))?;
176
177 match guard.try_io(|inner| inner.get_ref().write(buf)) {
178 Ok(Ok(n)) => return Ok(n).into(),
179 Ok(Err(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => continue,
180 Ok(Err(e)) => return Err(e).into(),
181 Err(_would_block) => continue,
182 }
183 }
184 }
185
186 pub(crate) fn poll_read_priv(
187 &self,
188 cx: &mut Context<'_>,
189 buf: &mut ReadBuf<'_>,
190 ) -> Poll<Result<()>> {
191 let b;
192 unsafe {
193 b = &mut *(buf.unfilled_mut() as *mut [mem::MaybeUninit<u8>] as *mut [u8]);
194 };
195
196 loop {
197 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
198
199 match guard.try_io(|inner| inner.get_ref().read(b)) {
200 Ok(Ok(n)) => {
201 unsafe {
202 buf.assume_init(n);
203 }
204 buf.advance(n);
205 return Ok(()).into();
206 }
207 Ok(Err(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => continue,
208 Ok(Err(e)) => return Err(e).into(),
209 Err(_would_block) => {
210 continue;
211 }
212 }
213 }
214 }
215}
216
217impl AsFd for VsockStream {
218 fn as_fd(&self) -> BorrowedFd<'_> {
219 self.inner.get_ref().as_fd()
220 }
221}
222
223impl AsRawFd for VsockStream {
224 fn as_raw_fd(&self) -> RawFd {
225 self.inner.get_ref().as_raw_fd()
226 }
227}
228
229impl IntoRawFd for VsockStream {
230 fn into_raw_fd(self) -> RawFd {
231 let fd = self.inner.get_ref().as_raw_fd();
232 mem::forget(self);
233 fd
234 }
235}
236
237impl Write for VsockStream {
238 fn write(&mut self, buf: &[u8]) -> Result<usize> {
239 self.inner.get_ref().write(buf)
240 }
241
242 fn flush(&mut self) -> Result<()> {
243 Ok(())
244 }
245}
246
247impl Read for VsockStream {
248 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
249 self.inner.get_ref().read(buf)
250 }
251}
252
253impl AsyncWrite for VsockStream {
254 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
255 self.poll_write_priv(cx, buf)
256 }
257
258 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
259 Poll::Ready(Ok(()))
260 }
261
262 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
263 self.shutdown(std::net::Shutdown::Write)?;
264 Poll::Ready(Ok(()))
265 }
266}
267
268impl AsyncRead for VsockStream {
269 fn poll_read(
270 self: Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 buf: &mut ReadBuf<'_>,
273 ) -> Poll<Result<()>> {
274 self.poll_read_priv(cx, buf)
275 }
276}