1#[macro_use]
9extern crate log;
10
11use std::fmt::{Display, Formatter};
12use std::path::Path;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16use vhost::vhost_user::{BackendListener, BackendReqHandler, Error as VhostUserError, Listener};
17use vm_memory::mmap::NewBitmap;
18use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
19
20use self::handler::VhostUserHandler;
21
22mod backend;
23pub use self::backend::{VhostUserBackend, VhostUserBackendMut};
24
25mod event_loop;
26pub use self::event_loop::VringEpollHandler;
27
28mod handler;
29pub use self::handler::VhostUserHandlerError;
30
31pub mod bitmap;
32use crate::bitmap::BitmapReplace;
33
34mod vring;
35pub use self::vring::{
36 VringMutex, VringRwLock, VringState, VringStateGuard, VringStateMutGuard, VringT,
37};
38
39#[cfg(all(feature = "postcopy", feature = "xen"))]
43compile_error!("Both `postcopy` and `xen` features can not be enabled at the same time.");
44
45type GM<B> = GuestMemoryAtomic<GuestMemoryMmap<B>>;
47
48#[derive(Debug)]
49pub enum Error {
51 NewVhostUserHandler(VhostUserHandlerError),
53 CreateBackendListener(VhostUserError),
55 CreateBackendReqHandler(VhostUserError),
57 CreateVhostUserListener(VhostUserError),
59 StartDaemon(std::io::Error),
61 WaitDaemon(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
63 HandleRequest(VhostUserError),
65}
66
67impl Display for Error {
68 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
69 match self {
70 Error::NewVhostUserHandler(e) => write!(f, "cannot create vhost user handler: {}", e),
71 Error::CreateBackendListener(e) => write!(f, "cannot create backend listener: {}", e),
72 Error::CreateBackendReqHandler(e) => {
73 write!(f, "cannot create backend req handler: {}", e)
74 }
75 Error::CreateVhostUserListener(e) => {
76 write!(f, "cannot create vhost-user listener: {}", e)
77 }
78 Error::StartDaemon(e) => write!(f, "failed to start daemon: {}", e),
79 Error::WaitDaemon(_e) => write!(f, "failed to wait for daemon exit"),
80 Error::HandleRequest(e) => write!(f, "failed to handle request: {}", e),
81 }
82 }
83}
84
85pub type Result<T> = std::result::Result<T, Error>;
87
88pub struct VhostUserDaemon<T: VhostUserBackend> {
93 name: String,
94 handler: Arc<Mutex<VhostUserHandler<T>>>,
95 main_thread: Option<thread::JoinHandle<Result<()>>>,
96}
97
98impl<T> VhostUserDaemon<T>
99where
100 T: VhostUserBackend + Clone + 'static,
101 T::Bitmap: BitmapReplace + NewBitmap + Clone + Send + Sync,
102 T::Vring: Clone + Send + Sync,
103{
104 pub fn new(
110 name: String,
111 backend: T,
112 atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<T::Bitmap>>,
113 ) -> Result<Self> {
114 let handler = Arc::new(Mutex::new(
115 VhostUserHandler::new(backend, atomic_mem).map_err(Error::NewVhostUserHandler)?,
116 ));
117
118 Ok(VhostUserDaemon {
119 name,
120 handler,
121 main_thread: None,
122 })
123 }
124
125 fn start_daemon(
132 &mut self,
133 mut handler: BackendReqHandler<Mutex<VhostUserHandler<T>>>,
134 ) -> Result<()> {
135 let handle = thread::Builder::new()
136 .name(self.name.clone())
137 .spawn(move || loop {
138 handler.handle_request().map_err(Error::HandleRequest)?;
139 })
140 .map_err(Error::StartDaemon)?;
141
142 self.main_thread = Some(handle);
143
144 Ok(())
145 }
146
147 pub fn start_client(&mut self, socket_path: &str) -> Result<()> {
152 let backend_handler = BackendReqHandler::connect(socket_path, self.handler.clone())
153 .map_err(Error::CreateBackendReqHandler)?;
154 self.start_daemon(backend_handler)
155 }
156
157 pub fn start(&mut self, listener: Listener) -> Result<()> {
168 let mut backend_listener = BackendListener::new(listener, self.handler.clone())
169 .map_err(Error::CreateBackendListener)?;
170 let backend_handler = self.accept(&mut backend_listener)?;
171 self.start_daemon(backend_handler)
172 }
173
174 fn accept(
175 &self,
176 backend_listener: &mut BackendListener<Mutex<VhostUserHandler<T>>>,
177 ) -> Result<BackendReqHandler<Mutex<VhostUserHandler<T>>>> {
178 loop {
179 match backend_listener.accept() {
180 Err(e) => return Err(Error::CreateBackendListener(e)),
181 Ok(Some(v)) => return Ok(v),
182 Ok(None) => continue,
183 }
184 }
185 }
186
187 pub fn wait(&mut self) -> Result<()> {
192 if let Some(handle) = self.main_thread.take() {
193 match handle.join().map_err(Error::WaitDaemon)? {
194 Ok(()) => Ok(()),
195 Err(Error::HandleRequest(VhostUserError::SocketBroken(_))) => Ok(()),
196 Err(e) => Err(e),
197 }
198 } else {
199 Ok(())
200 }
201 }
202
203 pub fn serve<P: AsRef<Path>>(&mut self, socket: P) -> Result<()> {
218 let listener = Listener::new(socket, true).map_err(Error::CreateVhostUserListener)?;
219
220 self.start(listener)?;
221 let result = self.wait();
222
223 self.handler.lock().unwrap().send_exit_event();
225
226 match &result {
230 Err(e) => match e {
231 Error::HandleRequest(VhostUserError::Disconnected) => Ok(()),
232 Error::HandleRequest(VhostUserError::PartialMessage) => Ok(()),
233 _ => result,
234 },
235 _ => result,
236 }
237 }
238
239 pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<T>>> {
244 self.handler.lock().unwrap().get_epoll_handlers()
246 }
247}
248
// Unit tests for `VhostUserDaemon`, driven by the `MockVhostBackend` from the backend module.
#[cfg(test)]
mod tests {
    use super::backend::tests::MockVhostBackend;
    use super::*;
    use libc::EAGAIN;
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::sync::Barrier;
    use std::time::Duration;
    use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};

    /// Server mode: the daemon listens, accepts one connection and reports the peer
    /// hanging up through `wait`.
    #[test]
    fn test_new_daemon() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();

        // NOTE(review): 2 presumably matches the number of worker threads the mock
        // backend asks for — confirm against MockVhostBackend.
        let handlers = daemon.get_epoll_handlers();
        assert_eq!(handlers.len(), 2);

        // Two parties rendezvous: this thread (daemon side) and the spawned client.
        let barrier = Arc::new(Barrier::new(2));
        let tmpdir = tempfile::tempdir().unwrap();
        let path = tmpdir.path().join("socket");

        thread::scope(|s| {
            s.spawn(|| {
                // Wait until the listener socket has been created before connecting.
                barrier.wait();
                let socket = UnixStream::connect(&path).unwrap();
                // Hold the connection until the daemon has started, then hang up.
                barrier.wait();
                drop(socket)
            });

            let listener = Listener::new(&path, false).unwrap();
            barrier.wait();
            daemon.start(listener).unwrap();
            barrier.wait();
            // The client dropped its end, so the request loop stops and `wait` reports its error.
            daemon.wait().unwrap_err();
            // The join handle was consumed above, so a second `wait` is a no-op returning Ok.
            daemon.wait().unwrap();
        });
    }

    /// Client mode: the daemon connects to a socket owned by the test and reports the
    /// peer hanging up through `wait`.
    #[test]
    fn test_new_daemon_client() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();

        let handlers = daemon.get_epoll_handlers();
        assert_eq!(handlers.len(), 2);

        let barrier = Arc::new(Barrier::new(2));
        let tmpdir = tempfile::tempdir().unwrap();
        let path = tmpdir.path().join("socket");

        thread::scope(|s| {
            s.spawn(|| {
                // Play the frontend: bind the socket, then signal that it is ready.
                let listener = UnixListener::bind(&path).unwrap();
                barrier.wait();
                let (stream, _) = listener.accept().unwrap();
                // Keep the connection until the daemon thread is running, then hang up.
                barrier.wait();
                drop(stream)
            });

            barrier.wait();
            daemon
                .start_client(path.as_path().to_str().unwrap())
                .unwrap();
            barrier.wait();
            // The peer dropped its end, so `wait` reports the request loop's error.
            daemon.wait().unwrap_err();
            // Nothing left to join: a second `wait` returns Ok.
            daemon.wait().unwrap();
        });
    }

    /// `serve` handles a single connection and, once it ends, signals every exit event.
    #[test]
    fn test_daemon_serve() {
        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
        );
        let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
        let mut daemon = VhostUserDaemon::new("test".to_owned(), backend.clone(), mem).unwrap();
        let tmpdir = tempfile::tempdir().unwrap();
        let socket_path = tmpdir.path().join("socket");

        thread::scope(|s| {
            s.spawn(|| {
                let _ = daemon.serve(&socket_path);
            });

            // Poll until `serve` has created the socket file so the connect below cannot race it.
            // NOTE(review): there is no timeout. If `serve` fails before creating the
            // socket, this loop spins forever instead of failing the test.
            while !socket_path.exists() {
                thread::sleep(Duration::from_millis(10));
            }

            // No exit event may be pending while `serve` is still waiting for a connection.
            // NOTE(review): a failed read with EAGAIN presumably means the exit event fd is
            // non-blocking and not yet signalled — confirm against the backend trait.
            for thread_id in 0..backend.queues_per_thread().len() {
                let fd = backend.exit_event(thread_id).unwrap();
                assert_eq!(
                    fd.read().unwrap_err().raw_os_error().unwrap(),
                    EAGAIN,
                    "exit event should not have been raised yet!"
                );
            }

            // Connect and immediately hang up; this ends the single session `serve` handles.
            let socket = UnixStream::connect(&socket_path).unwrap();
            drop(socket);
        });

        // The scope has joined the `serve` thread, so the exit events must now be readable.
        let backend = backend.lock().unwrap();
        for thread_id in 0..backend.queues_per_thread().len() {
            let fd = backend.exit_event(thread_id).unwrap();
            assert!(fd.read().is_ok(), "No exit event was raised!");
        }
    }
}