use std::sync::Mutex;
use std::time::Duration;
use crate::context::{ContextError, ContextErrorKind, ContractContext};
use crate::park::{WaitMessage, WaitThread};
use crate::time::Timer;
use crate::{Contract, ContractExt, Status};
use futures::{
future::{FusedFuture, Future},
task::{Context, Poll},
};
use parc::{LockWeak, ParentArc};
#[must_use = "contracts do nothing unless polled or awaited"]
pub struct OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
runner: WaitThread,
timer: Timer,
void_context: Option<ParentArc<Mutex<VC>>>,
prod_context: Option<ParentArc<Mutex<PC>>>,
on_exe: Option<F>,
}
impl<F, VC, PC, R> OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
pub fn new(expire: Duration, void_c: VC, prod_c: PC, on_exe: F) -> Self {
Self {
runner: WaitThread::new(),
timer: Timer::new(expire),
void_context: Some(ParentArc::new(Mutex::new(void_c))),
prod_context: Some(ParentArc::new(Mutex::new(prod_c))),
on_exe: Some(on_exe),
}
}
fn poll_prod(&self) -> bool {
match &self.prod_context {
Some(c) => c.as_ref().lock().unwrap().poll_valid(),
None => false,
}
}
pin_utils::unsafe_pinned!(timer: Timer);
pin_utils::unsafe_unpinned!(void_context: Option<ParentArc<Mutex<VC>>>);
pin_utils::unsafe_unpinned!(prod_context: Option<ParentArc<Mutex<PC>>>);
pin_utils::unsafe_unpinned!(on_exe: Option<F>);
}
impl<F, VC, PC, R> Contract for OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
fn poll_valid(&self) -> bool {
match &self.void_context {
Some(c) => c.as_ref().lock().unwrap().poll_valid(),
None => false,
}
}
fn execute(mut self: std::pin::Pin<&mut Self>) -> Self::Output {
let vlockarc = self
.as_mut()
.void_context()
.take()
.expect("Cannot poll after expiration");
let plockarc = self
.as_mut()
.prod_context()
.take()
.expect("Cannot poll after expiration");
let vcontext = vlockarc.block_into_inner().into_inner().unwrap();
let pcontext = plockarc.block_into_inner().into_inner().unwrap();
let f = self
.as_mut()
.on_exe()
.take()
.expect("Cannot run a contract after expiration");
Status::Completed(f((vcontext, pcontext)))
}
fn void(self: std::pin::Pin<&mut Self>) -> Self::Output {
Status::Terminated
}
}
impl<F, VC, PC, R> ContractExt for OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
type Context = (LockWeak<Mutex<VC>>, LockWeak<Mutex<PC>>);
fn get_context(&self) -> Result<Self::Context, ContextError> {
match (&self.void_context, &self.prod_context) {
(Some(ref vc), Some(ref pc)) => {
Ok((ParentArc::downgrade(vc), ParentArc::downgrade(pc)))
}
_ => Err(ContextError::from(ContextErrorKind::ExpiredContext)),
}
}
}
impl<F, VC, PC, R> Future for OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
type Output = Status<R>;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.runner
.sender()
.send(WaitMessage::WakeIn {
waker: cx.waker().clone(),
duration: Duration::new(0, 100),
})
.unwrap();
let mv = (
self.as_mut().timer().poll(cx),
self.poll_valid(),
self.poll_prod(),
);
match mv {
(Poll::Ready(_), true, true) => Poll::Ready(self.execute()),
(Poll::Ready(_), true, false) => Poll::Ready(self.void()),
(Poll::Pending, true, _) => Poll::Pending,
(_, false, _) => Poll::Ready(self.void()),
}
}
}
impl<F, VC, PC, R> FusedFuture for OptionContract<F, VC, PC, R>
where
VC: ContractContext,
PC: ContractContext,
F: FnOnce((VC, PC)) -> R,
{
fn is_terminated(&self) -> bool {
self.void_context.is_none() || self.prod_context.is_none() || self.on_exe.is_none()
}
}
#[cfg(test)]
mod tests {
use super::OptionContract;
use crate::context::cmp::EqContext;
use crate::{ContractExt, Status};
use std::time::Duration;
#[test]
fn prod_option_contract() {
let vcontext = EqContext(2, 2);
let pcontext = EqContext(2, 2);
let c = OptionContract::new(
Duration::new(1, 0),
vcontext,
pcontext,
|(vcon, pcon)| -> usize { vcon.0 + pcon.0 + 1 },
);
if let Status::Completed(val) = futures::executor::block_on(c) {
assert_eq!(val, 5);
} else {
assert!(false);
}
}
#[test]
fn void_option_contract() {
let vcontext = EqContext(2, 2);
let pcontext = EqContext(2, 2);
let c = OptionContract::new(
Duration::new(1, 0),
vcontext,
pcontext,
|(vcon, pcon)| -> usize { vcon.0 + pcon.0 + 1 },
);
let handle = std::thread::spawn({
let (vcontext, _) = c.get_context().unwrap();
move || match vcontext.upgrade() {
Some(vc) => vc.lock().unwrap().0 += 1,
None => {}
}
});
if let Status::Completed(val) = futures::executor::block_on(c) {
assert_ne!(val, 6);
} else {
assert!(true);
}
handle.join().unwrap();
}
#[test]
fn noprod_option_contract() {
let vcontext = EqContext(2, 2);
let pcontext = EqContext(2, 2);
let c = OptionContract::new(
Duration::new(1, 0),
vcontext,
pcontext,
|(vcon, pcon)| -> usize { vcon.0 + pcon.0 + 1 },
);
let _ = std::thread::spawn({
let (_, pcontext) = c.get_context().unwrap();
move || match pcontext.upgrade() {
Some(pc) => pc.lock().unwrap().0 += 1,
None => {}
}
})
.join();
if let Status::Completed(val) = futures::executor::block_on(c) {
assert_ne!(val, 6);
} else {
assert!(true);
}
}
}