use std::{cell::RefCell, io::ErrorKind, rc::Rc, time::Duration};
use indexmap::IndexSet;
use zerocopy::{FromBytes, IntoBytes};
use crate::{
context::VmiContext, os::VmiOs, AccessContext, Architecture, PageFault, PageFaults,
TranslationMechanism, Va, VmiCore, VmiDriver, VmiError, VmiHandler,
};
/// A VMI session: a [`VmiCore`] paired with an OS-specific implementation.
///
/// The session dereferences to [`VmiCore`], so core methods can be called
/// directly on it.
pub struct VmiSession<Driver, Os>
where
Driver: VmiDriver,
Os: VmiOs<Driver>,
{
/// The core the session forwards driver-level operations to.
pub(crate) core: VmiCore<Driver>,
/// The OS-specific implementation associated with this session.
pub(crate) os: Os,
}
/// Lets a [`VmiSession`] be used wherever a `&VmiCore<Driver>` is expected.
impl<Driver, Os> std::ops::Deref for VmiSession<Driver, Os>
where
    Driver: VmiDriver,
    Os: VmiOs<Driver>,
{
    type Target = VmiCore<Driver>;

    /// Borrows the core stored inside the session.
    fn deref(&self) -> &VmiCore<Driver> {
        let Self { core, .. } = self;
        core
    }
}
impl<Driver, Os> VmiSession<Driver, Os>
where
Driver: VmiDriver,
Os: VmiOs<Driver>,
{
pub fn new(core: VmiCore<Driver>, os: Os) -> Self {
Self { core, os }
}
pub fn core(&self) -> &VmiCore<Driver> {
&self.core
}
pub fn underlying_os(&self) -> &Os {
&self.os
}
pub fn os(&self) -> VmiOsSession<Driver, Os> {
VmiOsSession(self)
}
pub fn prober<'a>(
&'a self,
restricted: &IndexSet<PageFault>,
) -> VmiSessionProber<'a, Driver, Os> {
VmiSessionProber::new(self, restricted)
}
pub fn wait_for_event(
&self,
timeout: Duration,
handler: &mut impl VmiHandler<Driver, Os>,
) -> Result<(), VmiError> {
self.core.wait_for_event(
timeout,
Box::new(|event| handler.handle_event(VmiContext::new(self, event))),
)
}
pub fn handle<Handler>(
&self,
handler_factory: impl FnOnce(&VmiSession<Driver, Os>) -> Result<Handler, VmiError>,
) -> Result<(), VmiError>
where
Handler: VmiHandler<Driver, Os>,
{
let mut handler = handler_factory(self)?;
while !handler.finished() {
match self.wait_for_event(Duration::from_millis(5000), &mut handler) {
Err(VmiError::Timeout) => {
tracing::trace!("timeout");
}
Err(VmiError::Io(err)) if err.kind() == ErrorKind::Interrupted => {
tracing::info!("interrupted");
break;
}
Err(err) => return Err(err),
Ok(_) => {}
}
}
tracing::trace!("disabling monitor");
self.core.reset_state()?;
tracing::trace!(pending_events = self.events_pending());
let _pause_guard = self.pause_guard()?;
if self.events_pending() > 0 {
match self.wait_for_event(Duration::from_millis(0), &mut handler) {
Err(VmiError::Timeout) => {
tracing::trace!("timeout");
}
Err(err) => return Err(err),
Ok(_) => {}
}
}
Ok(())
}
}
/// Wrapper around a borrowed [`VmiSession`], as returned by [`VmiSession::os`].
///
/// NOTE(review): presumably the receiver for OS-specific operations that are
/// implemented elsewhere in the crate — confirm.
pub struct VmiOsSession<'a, Driver, Os>(pub(crate) &'a VmiSession<Driver, Os>)
where
Driver: VmiDriver,
Os: VmiOs<Driver>;
impl<Driver, Os> VmiOsSession<'_, Driver, Os>
where
    Driver: VmiDriver,
    Os: VmiOs<Driver>,
{
    /// Returns the wrapped [`VmiSession`].
    pub fn core(&self) -> &VmiSession<Driver, Os> {
        let Self(session) = *self;
        session
    }

    /// Returns the OS-specific implementation of the wrapped session.
    pub fn underlying_os(&self) -> &Os {
        self.core().underlying_os()
    }
}
/// Wrapper around a [`VmiSession`] whose read methods turn page faults into
/// `Ok(None)` and record them for later inspection.
///
/// The prober dereferences to the wrapped [`VmiSession`].
pub struct VmiSessionProber<'a, Driver, Os>
where
Driver: VmiDriver,
Os: VmiOs<Driver>,
{
/// The session that reads are forwarded to.
pub(crate) session: &'a VmiSession<Driver, Os>,
/// Page faults that are logged but never recorded.
pub(crate) restricted: Rc<IndexSet<PageFault>>,
/// Page faults recorded so far that were not in `restricted`.
pub(crate) page_faults: Rc<RefCell<IndexSet<PageFault>>>,
}
/// Lets a [`VmiSessionProber`] be used wherever a `&VmiSession` is expected.
impl<Driver, Os> std::ops::Deref for VmiSessionProber<'_, Driver, Os>
where
    Driver: VmiDriver,
    Os: VmiOs<Driver>,
{
    type Target = VmiSession<Driver, Os>;

    /// Borrows the session this prober wraps.
    fn deref(&self) -> &VmiSession<Driver, Os> {
        let Self { session, .. } = self;
        session
    }
}
impl<'a, Driver, Os> VmiSessionProber<'a, Driver, Os>
where
Driver: VmiDriver,
Os: VmiOs<Driver>,
{
pub fn new(session: &'a VmiSession<Driver, Os>, restricted: &IndexSet<PageFault>) -> Self {
Self {
session,
restricted: Rc::new(restricted.clone()),
page_faults: Rc::new(RefCell::new(IndexSet::new())),
}
}
#[tracing::instrument(skip_all)]
pub fn error_for_page_faults(&self) -> Result<(), VmiError> {
let pfs = self.page_faults.borrow();
let new_pfs = &*pfs - &self.restricted;
if !new_pfs.is_empty() {
tracing::trace!(?new_pfs);
return Err(VmiError::page_faults(new_pfs));
}
Ok(())
}
pub fn os(&'a self) -> VmiOsSessionProber<'a, Driver, Os> {
VmiOsSessionProber(self)
}
pub fn read(
&self,
ctx: impl Into<AccessContext>,
buffer: &mut [u8],
) -> Result<Option<()>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read(ctx, buffer), ctx, buffer.len())
}
pub fn read_u8(&self, ctx: impl Into<AccessContext>) -> Result<Option<u8>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_u8(ctx), ctx, size_of::<u8>())
}
pub fn read_u16(&self, ctx: impl Into<AccessContext>) -> Result<Option<u16>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_u16(ctx), ctx, size_of::<u16>())
}
pub fn read_u32(&self, ctx: impl Into<AccessContext>) -> Result<Option<u32>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_u32(ctx), ctx, size_of::<u32>())
}
pub fn read_u64(&self, ctx: impl Into<AccessContext>) -> Result<Option<u64>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_u64(ctx), ctx, size_of::<u64>())
}
pub fn read_va(
&self,
ctx: impl Into<AccessContext>,
address_width: usize,
) -> Result<Option<Va>, VmiError> {
let ctx = ctx.into();
self.check_result_range(
self.session.core().read_va(ctx, address_width),
ctx,
address_width,
)
}
pub fn read_va32(&self, ctx: impl Into<AccessContext>) -> Result<Option<Va>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_va32(ctx), ctx, size_of::<u32>())
}
pub fn read_va64(&self, ctx: impl Into<AccessContext>) -> Result<Option<Va>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_va64(ctx), ctx, size_of::<u64>())
}
pub fn read_string_bytes(
&self,
ctx: impl Into<AccessContext>,
) -> Result<Option<Vec<u8>>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_string_bytes(ctx), ctx, 1)
}
pub fn read_wstring_bytes(
&self,
ctx: impl Into<AccessContext>,
) -> Result<Option<Vec<u16>>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_wstring_bytes(ctx), ctx, 2)
}
pub fn read_string(&self, ctx: impl Into<AccessContext>) -> Result<Option<String>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_string(ctx), ctx, 1)
}
pub fn read_wstring(&self, ctx: impl Into<AccessContext>) -> Result<Option<String>, VmiError> {
let ctx = ctx.into();
self.check_result_range(self.session.core().read_wstring(ctx), ctx, 2)
}
pub fn read_struct<T>(&self, ctx: impl Into<AccessContext>) -> Result<Option<T>, VmiError>
where
T: IntoBytes + FromBytes,
{
let ctx = ctx.into();
self.check_result_range(self.session.core().read_struct(ctx), ctx, size_of::<T>())
}
pub fn check_result<T>(&self, result: Result<T, VmiError>) -> Result<Option<T>, VmiError> {
match result {
Ok(value) => Ok(Some(value)),
Err(VmiError::PageFault(pfs)) => {
self.check_restricted(pfs);
Ok(None)
}
Err(err) => Err(err),
}
}
fn check_result_range<T>(
&self,
result: Result<T, VmiError>,
ctx: AccessContext,
length: usize,
) -> Result<Option<T>, VmiError> {
match result {
Ok(value) => Ok(Some(value)),
Err(VmiError::PageFault(pfs)) => {
debug_assert_eq!(pfs.len(), 1);
self.check_restricted_range(pfs[0], ctx, length);
Ok(None)
}
Err(err) => Err(err),
}
}
fn check_restricted(&self, pfs: PageFaults) {
let mut page_faults = self.page_faults.borrow_mut();
for pf in pfs {
if !self.restricted.contains(&pf) {
tracing::trace!(va = %pf.address, "page fault");
page_faults.insert(pf);
}
else {
tracing::trace!(va = %pf.address, "page fault (restricted)");
}
}
}
fn check_restricted_range(&self, pf: PageFault, ctx: AccessContext, mut length: usize) {
let mut page_faults = self.page_faults.borrow_mut();
if length == 0 {
length = 1;
}
let pf_page = pf.address.0 >> Driver::Architecture::PAGE_SHIFT;
let last_page = (ctx.address + length as u64 - 1) >> Driver::Architecture::PAGE_SHIFT;
let number_of_pages = last_page - pf_page + 1;
let pf_address_aligned = Va(pf_page << Driver::Architecture::PAGE_SHIFT);
let last_address_aligned = Va(last_page << Driver::Architecture::PAGE_SHIFT);
if number_of_pages > 1 {
tracing::debug!(
from = %pf_address_aligned,
to = %last_address_aligned,
number_of_pages,
"page fault (range)"
);
if number_of_pages >= 4096 {
tracing::warn!(
from = %pf_address_aligned,
to = %last_address_aligned,
number_of_pages,
"page fault range too large"
);
}
}
for i in 0..number_of_pages {
debug_assert_eq!(
pf.root,
match ctx.mechanism {
TranslationMechanism::Paging { root: Some(root) } => root,
_ => panic!("page fault root doesn't match the context root"),
}
);
let pf = PageFault {
address: pf_address_aligned + i * Driver::Architecture::PAGE_SIZE,
root: pf.root,
};
if !self.restricted.contains(&pf) {
tracing::trace!(va = %pf.address, "page fault");
page_faults.insert(pf);
}
else {
tracing::trace!(va = %pf.address, "page fault (restricted)");
}
}
}
}
/// Wrapper around a borrowed [`VmiSessionProber`], as returned by
/// [`VmiSessionProber::os`].
///
/// NOTE(review): presumably the receiver for OS-specific probing operations
/// that are implemented elsewhere in the crate — confirm.
pub struct VmiOsSessionProber<'a, Driver, Os>(pub(crate) &'a VmiSessionProber<'a, Driver, Os>)
where
Driver: VmiDriver,
Os: VmiOs<Driver>;
impl<Driver, Os> VmiOsSessionProber<'_, Driver, Os>
where
    Driver: VmiDriver,
    Os: VmiOs<Driver>,
{
    /// Returns the wrapped [`VmiSessionProber`].
    pub fn core(&self) -> &VmiSessionProber<'_, Driver, Os> {
        let Self(prober) = *self;
        prober
    }

    /// Returns the OS-specific implementation of the session behind the
    /// wrapped prober.
    pub fn underlying_os(&self) -> &Os {
        self.core().underlying_os()
    }
}