use std::error;
use std::fs::File;
use std::io;
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
use std::thread;

use log::error;
use vhost::vhost_user::message::{
VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures,
VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags,
VhostUserVringState,
};
use vhost::vhost_user::{
Backend, Error as VhostUserError, Result as VhostUserResult, VhostUserBackendReqHandlerMut,
};
use virtio_bindings::bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use virtio_queue::{Error as VirtQueError, QueueT};
use vm_memory::mmap::NewBitmap;
use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryMmap, GuestRegionMmap};
use vmm_sys_util::epoll::EventSet;
use super::backend::VhostUserBackend;
use super::event_loop::VringEpollHandler;
use super::event_loop::{VringEpollError, VringEpollResult};
use super::vring::VringT;
use super::GM;
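// Maximum number of memory slots advertised through
// VHOST_USER_GET_MAX_MEM_SLOTS. The value presumably mirrors KVM's
// KVM_USER_MEM_SLOTS limit on x86-64 (512 slots minus 3 reserved for
// internal use), so a KVM-based frontend is never told it can map more
// regions than the hypervisor supports.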
const MAX_MEM_SLOTS: u64 = 509;
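/// Errors raised while creating or running the vhost-user request handler.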
#[derive(Debug)]
pub enum VhostUserHandlerError {
CreateVring(VirtQueError),
CreateEpollHandler(VringEpollError),
SpawnVringWorker(io::Error),
MissingMemoryMapping,
}
impl std::fmt::Display for VhostUserHandlerError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
VhostUserHandlerError::CreateVring(e) => {
write!(f, "failed to create vring: {}", e)
}
VhostUserHandlerError::CreateEpollHandler(e) => {
write!(f, "failed to create vring epoll handler: {}", e)
}
VhostUserHandlerError::SpawnVringWorker(e) => {
write!(f, "failed spawning the vring worker: {}", e)
}
            VhostUserHandlerError::MissingMemoryMapping => write!(f, "missing memory mapping"),
}
}
}
impl error::Error for VhostUserHandlerError {}
pub type VhostUserHandlerResult<T> = std::result::Result<T, VhostUserHandlerError>;
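/// Mapping between a VMM virtual-address range and the guest-physical
/// address it backs, recorded when the frontend shares its memory table.
/// Vring addresses arrive as VMM virtual addresses, so they must be
/// translated back to guest-physical ones through these mappings.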
struct AddrMapping {
vmm_addr: u64,
size: u64,
gpa_base: u64,
}
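/// Request handler bridging the vhost-user protocol to a
/// [`VhostUserBackend`] implementation. It owns the vrings, the epoll
/// handlers servicing them, and the worker threads running those handlers.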
pub struct VhostUserHandler<T: VhostUserBackend> {
backend: T,
handlers: Vec<Arc<VringEpollHandler<T>>>,
owned: bool,
features_acked: bool,
acked_features: u64,
acked_protocol_features: u64,
num_queues: usize,
max_queue_size: usize,
queues_per_thread: Vec<u64>,
mappings: Vec<AddrMapping>,
atomic_mem: GM<T::Bitmap>,
vrings: Vec<T::Vring>,
worker_threads: Vec<thread::JoinHandle<VringEpollResult<()>>>,
}
impl<T> VhostUserHandler<T>
where
T: VhostUserBackend + Clone + 'static,
T::Vring: Clone + Send + Sync + 'static,
T::Bitmap: Clone + Send + Sync + 'static,
{
pub(crate) fn new(backend: T, atomic_mem: GM<T::Bitmap>) -> VhostUserHandlerResult<Self> {
let num_queues = backend.num_queues();
let max_queue_size = backend.max_queue_size();
let queues_per_thread = backend.queues_per_thread();
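        // Create one vring per queue up front. The vrings stay inactive
        // until the frontend configures them via the SET_VRING_* messages.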
let mut vrings = Vec::new();
for _ in 0..num_queues {
let vring = T::Vring::new(atomic_mem.clone(), max_queue_size as u16)
.map_err(VhostUserHandlerError::CreateVring)?;
vrings.push(vring);
}
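        // Spawn one epoll worker thread per entry of `queues_per_thread`.
        // Each entry is a bitmask selecting the vrings that thread services.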
let mut handlers = Vec::new();
let mut worker_threads = Vec::new();
for (thread_id, queues_mask) in queues_per_thread.iter().enumerate() {
let mut thread_vrings = Vec::new();
for (index, vring) in vrings.iter().enumerate() {
if (queues_mask >> index) & 1u64 == 1u64 {
thread_vrings.push(vring.clone());
}
}
let handler = Arc::new(
VringEpollHandler::new(backend.clone(), thread_vrings, thread_id)
.map_err(VhostUserHandlerError::CreateEpollHandler)?,
);
let handler2 = handler.clone();
let worker_thread = thread::Builder::new()
.name("vring_worker".to_string())
.spawn(move || handler2.run())
.map_err(VhostUserHandlerError::SpawnVringWorker)?;
handlers.push(handler);
worker_threads.push(worker_thread);
}
Ok(VhostUserHandler {
backend,
handlers,
owned: false,
features_acked: false,
acked_features: 0,
acked_protocol_features: 0,
num_queues,
max_queue_size,
queues_per_thread,
mappings: Vec::new(),
atomic_mem,
vrings,
worker_threads,
})
}
}
impl<T: VhostUserBackend> VhostUserHandler<T> {
pub(crate) fn send_exit_event(&self) {
for handler in self.handlers.iter() {
handler.send_exit_event();
}
}
fn vmm_va_to_gpa(&self, vmm_va: u64) -> VhostUserHandlerResult<u64> {
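        // Translate a VMM virtual address to a guest physical address by
        // locating the memory region that contains it.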
for mapping in self.mappings.iter() {
if vmm_va >= mapping.vmm_addr && vmm_va < mapping.vmm_addr + mapping.size {
return Ok(vmm_va - mapping.vmm_addr + mapping.gpa_base);
}
}
Err(VhostUserHandlerError::MissingMemoryMapping)
}
}
impl<T> VhostUserHandler<T>
where
T: VhostUserBackend,
{
pub(crate) fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<T>>> {
self.handlers.clone()
}
fn vring_needs_init(&self, vring: &T::Vring) -> bool {
let vring_state = vring.get_ref();
!vring_state.get_queue().ready() && vring_state.get_kick().is_some()
}
fn initialize_vring(&self, vring: &T::Vring, index: u8) -> VhostUserResult<()> {
assert!(vring.get_ref().get_kick().is_some());
if let Some(fd) = vring.get_ref().get_kick() {
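            // Register the kick fd with the epoll handler of the worker
            // thread owning this queue. The event index passed to epoll is
            // the queue's ordinal among that thread's queues, i.e. the
            // number of lower bits set in the thread's queue mask.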
for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() {
let shifted_queues_mask = queues_mask >> index;
if shifted_queues_mask & 1u64 == 1u64 {
let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones();
self.handlers[thread_index]
.register_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx))
.map_err(VhostUserError::ReqHandlerError)?;
break;
}
}
}
vring.set_queue_ready(true);
Ok(())
}
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> {
if self.acked_features & feat.bits() != 0 {
Ok(())
} else {
Err(VhostUserError::InactiveFeature(feat))
}
}
}
impl<T: VhostUserBackend> VhostUserBackendReqHandlerMut for VhostUserHandler<T>
where
T::Bitmap: NewBitmap + Clone,
{
fn set_owner(&mut self) -> VhostUserResult<()> {
if self.owned {
return Err(VhostUserError::InvalidOperation("already claimed"));
}
self.owned = true;
Ok(())
}
fn reset_owner(&mut self) -> VhostUserResult<()> {
self.owned = false;
self.features_acked = false;
self.acked_features = 0;
self.acked_protocol_features = 0;
Ok(())
}
fn get_features(&mut self) -> VhostUserResult<u64> {
Ok(self.backend.features())
}
fn set_features(&mut self, features: u64) -> VhostUserResult<()> {
if (features & !self.backend.features()) != 0 {
return Err(VhostUserError::InvalidParam);
}
self.acked_features = features;
self.features_acked = true;
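        // Per the vhost-user spec, if VHOST_USER_F_PROTOCOL_FEATURES has not
        // been negotiated the rings start directly in the enabled state;
        // otherwise they remain disabled until VHOST_USER_SET_VRING_ENABLE.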
if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
for vring in self.vrings.iter_mut() {
vring.set_enabled(true);
}
}
self.backend.acked_features(self.acked_features);
Ok(())
}
fn set_mem_table(
&mut self,
ctx: &[VhostUserMemoryRegion],
files: Vec<File>,
) -> VhostUserResult<()> {
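        // Rebuild the guest memory map from the regions supplied by the
        // frontend, remembering each region's VMM virtual address so vring
        // addresses can be translated later.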
let mut regions = Vec::new();
let mut mappings: Vec<AddrMapping> = Vec::new();
for (region, file) in ctx.iter().zip(files) {
regions.push(
GuestRegionMmap::new(
region.mmap_region(file)?,
GuestAddress(region.guest_phys_addr),
)
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?,
);
mappings.push(AddrMapping {
vmm_addr: region.user_addr,
size: region.memory_size,
gpa_base: region.guest_phys_addr,
});
}
let mem = GuestMemoryMmap::from_regions(regions).map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.atomic_mem.lock().unwrap().replace(mem);
self.backend
.update_memory(self.atomic_mem.clone())
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.mappings = mappings;
Ok(())
}
fn set_vring_num(&mut self, index: u32, num: u32) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
if num == 0 || num as usize > self.max_queue_size {
return Err(VhostUserError::InvalidParam);
}
vring.set_queue_size(num as u16);
Ok(())
}
fn set_vring_addr(
&mut self,
index: u32,
_flags: VhostUserVringAddrFlags,
descriptor: u64,
used: u64,
available: u64,
_log: u64,
) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
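        // The frontend sends ring addresses as VMM virtual addresses;
        // translate them to guest physical addresses before configuring the
        // queue.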
if !self.mappings.is_empty() {
let desc_table = self.vmm_va_to_gpa(descriptor).map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
let avail_ring = self.vmm_va_to_gpa(available).map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
let used_ring = self.vmm_va_to_gpa(used).map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
vring
.set_queue_info(desc_table, avail_ring, used_ring)
.map_err(|_| VhostUserError::InvalidParam)?;
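            // Restore the used index from guest memory, so a ring that was
            // already in use (e.g. across a frontend reconnection) resumes
            // where it left off.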
let idx = vring
.queue_used_idx()
.map_err(|_| VhostUserError::BackendInternalError)?;
vring.set_queue_next_used(idx);
Ok(())
} else {
Err(VhostUserError::InvalidParam)
}
}
fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
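        // VIRTIO_RING_F_EVENT_IDX is part of the negotiated virtio feature
        // bits; propagate it to both the queue and the backend.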
let event_idx: bool = (self.acked_features & (1 << VIRTIO_RING_F_EVENT_IDX)) != 0;
vring.set_queue_next_avail(base as u16);
vring.set_queue_event_idx(event_idx);
self.backend.set_event_idx(event_idx);
Ok(())
}
fn get_vring_base(&mut self, index: u32) -> VhostUserResult<VhostUserVringState> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
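        // Per the vhost-user spec, GET_VRING_BASE stops the ring: mark the
        // queue not ready and unregister its kick fd from the worker's epoll
        // loop before reporting the next available index.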
vring.set_queue_ready(false);
if let Some(fd) = vring.get_ref().get_kick() {
for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() {
let shifted_queues_mask = queues_mask >> index;
if shifted_queues_mask & 1u64 == 1u64 {
let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones();
self.handlers[thread_index]
.unregister_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx))
.map_err(VhostUserError::ReqHandlerError)?;
break;
}
}
}
let next_avail = vring.queue_next_avail();
vring.set_kick(None);
vring.set_call(None);
Ok(VhostUserVringState::new(index, u32::from(next_avail)))
}
fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
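        // Store the new kick fd, then initialize the vring lazily once it
        // has a kick fd but has not yet been marked ready.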
vring.set_kick(file);
if self.vring_needs_init(vring) {
self.initialize_vring(vring, index)?;
}
Ok(())
}
fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
vring.set_call(file);
if self.vring_needs_init(vring) {
self.initialize_vring(vring, index)?;
}
Ok(())
}
fn set_vring_err(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
vring.set_err(file);
Ok(())
}
fn get_protocol_features(&mut self) -> VhostUserResult<VhostUserProtocolFeatures> {
Ok(self.backend.protocol_features())
}
fn set_protocol_features(&mut self, features: u64) -> VhostUserResult<()> {
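        // A backend advertising VHOST_USER_F_PROTOCOL_FEATURES must accept
        // this message even before VHOST_USER_SET_FEATURES, so no check of
        // `features_acked` is performed here.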
self.acked_protocol_features = features;
Ok(())
}
fn get_queue_num(&mut self) -> VhostUserResult<u64> {
Ok(self.num_queues as u64)
}
fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostUserResult<()> {
self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
let vring = self
.vrings
.get(index as usize)
            .ok_or(VhostUserError::InvalidParam)?;
vring.set_enabled(enable);
Ok(())
}
fn get_config(
&mut self,
offset: u32,
size: u32,
_flags: VhostUserConfigFlags,
) -> VhostUserResult<Vec<u8>> {
Ok(self.backend.get_config(offset, size))
}
fn set_config(
&mut self,
offset: u32,
buf: &[u8],
_flags: VhostUserConfigFlags,
) -> VhostUserResult<()> {
self.backend
.set_config(offset, buf)
.map_err(VhostUserError::ReqHandlerError)
}
fn set_backend_req_fd(&mut self, backend: Backend) {
if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() != 0 {
backend.set_reply_ack_flag(true);
}
self.backend.set_backend_req_fd(backend);
}
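    // Inflight descriptor tracking (VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)
    // is not supported by this handler.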
fn get_inflight_fd(
&mut self,
_inflight: &vhost::vhost_user::message::VhostUserInflight,
) -> VhostUserResult<(vhost::vhost_user::message::VhostUserInflight, File)> {
Err(VhostUserError::InvalidOperation("not supported"))
}
fn set_inflight_fd(
&mut self,
_inflight: &vhost::vhost_user::message::VhostUserInflight,
_file: File,
) -> VhostUserResult<()> {
Err(VhostUserError::InvalidOperation("not supported"))
}
fn get_max_mem_slots(&mut self) -> VhostUserResult<u64> {
Ok(MAX_MEM_SLOTS)
}
fn add_mem_region(
&mut self,
region: &VhostUserSingleMemoryRegion,
file: File,
) -> VhostUserResult<()> {
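        // Insert the new region into a copy of the current memory map, then
        // atomically swap it in; readers holding the old map keep a
        // consistent snapshot until they drop it.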
let guest_region = Arc::new(
GuestRegionMmap::new(
region.mmap_region(file)?,
GuestAddress(region.guest_phys_addr),
)
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?,
);
let mem = self
.atomic_mem
.memory()
.insert_region(guest_region)
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.atomic_mem.lock().unwrap().replace(mem);
self.backend
.update_memory(self.atomic_mem.clone())
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.mappings.push(AddrMapping {
vmm_addr: region.user_addr,
size: region.memory_size,
gpa_base: region.guest_phys_addr,
});
Ok(())
}
fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> VhostUserResult<()> {
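        // Remove the region from a copy of the current memory map, swap the
        // new map in, and drop the matching VMM address mapping.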
let (mem, _) = self
.atomic_mem
.memory()
.remove_region(GuestAddress(region.guest_phys_addr), region.memory_size)
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.atomic_mem.lock().unwrap().replace(mem);
self.backend
.update_memory(self.atomic_mem.clone())
.map_err(|e| {
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
})?;
self.mappings
.retain(|mapping| mapping.gpa_base != region.guest_phys_addr);
Ok(())
}
}
impl<T: VhostUserBackend> Drop for VhostUserHandler<T> {
fn drop(&mut self) {
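        // Ask every worker thread to exit, then join them so the epoll
        // loops never outlive the handler.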
self.send_exit_event();
for thread in self.worker_threads.drain(..) {
if let Err(e) = thread.join() {
error!("Error in vring worker: {:?}", e);
}
}
}
}