use std::{net::SocketAddr, sync::Arc, time::Duration};
use crate::traits::*;
use crate::{CoarseInstant, CoarseTimeProvider};
use async_trait::async_trait;
use educe::Educe;
use futures::{future::FutureObj, task::Spawn};
use std::io::Result as IoResult;
use std::time::{Instant, SystemTime};
/// A runtime assembled from separately chosen components.
///
/// Spawning/blocking, sleeping, coarse time, TCP, TLS and UDP are each
/// forwarded to their own component type, so implementations can be mixed.
/// All components are held behind a single shared [`Arc`].
//
// `#[derive(Clone)]` would add a `Clone` bound to every type parameter;
// presumably `educe` is used so that cloning only needs to clone the `Arc`
// (confirm against the educe version in use).
#[derive(Educe)]
#[educe(Clone)]
pub struct CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> {
    /// Shared storage for every component.
    inner: Arc<Inner<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>>,
}
/// The components of a [`CompoundRuntime`], stored together in one
/// allocation behind the runtime's `Arc`.
//
// NOTE: field declaration order is also drop order; keep it stable unless
// that is known not to matter for the component types.
struct Inner<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> {
    /// Component used by the `Spawn` and `BlockOn` impls.
    spawn: SpawnR,
    /// Component used by the `SleepProvider` impl.
    sleep: SleepR,
    /// Component used by the `CoarseTimeProvider` impl.
    coarse_time: CoarseTimeR,
    /// Component used by the `TcpProvider` impl.
    tcp: TcpR,
    /// Component used by the `TlsProvider` impl.
    tls: TlsR,
    /// Component used by the `UdpProvider` impl.
    udp: UdpR,
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
    CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
{
    /// Build a runtime from one component per capability.
    ///
    /// Every argument is moved, unchanged, into a single shared allocation;
    /// the trait impls on this type forward to the matching component.
    pub fn new(
        spawn: SpawnR,
        sleep: SleepR,
        coarse_time: CoarseTimeR,
        tcp: TcpR,
        tls: TlsR,
        udp: UdpR,
    ) -> Self {
        // No `Send`/`Sync` bounds are imposed on the components here (the
        // trait impls add them where they are needed), so clippy's lint about
        // an `Arc` around possibly non-`Send`/`Sync` data is silenced.
        #[allow(clippy::arc_with_non_send_sync)]
        let shared = Arc::new(Inner {
            spawn,
            sleep,
            coarse_time,
            tcp,
            tls,
            udp,
        });
        Self { inner: shared }
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> Spawn
for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
SpawnR: Spawn,
{
#[inline]
#[track_caller]
fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
self.inner.spawn.spawn_obj(future)
}
}
/// Blocking on a future is forwarded to the `SpawnR` component.
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> BlockOn
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
    SpawnR: BlockOn,
    SleepR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// Delegate to the spawning component's `block_on` and return the
    /// output it produces for `future`.
    #[inline]
    #[track_caller]
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
        let executor = &self.inner.spawn;
        executor.block_on(future)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> SleepProvider
for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
SleepR: SleepProvider,
SpawnR: Clone + Send + Sync + 'static,
CoarseTimeR: Clone + Send + Sync + 'static,
TcpR: Clone + Send + Sync + 'static,
TlsR: Clone + Send + Sync + 'static,
UdpR: Clone + Send + Sync + 'static,
{
type SleepFuture = SleepR::SleepFuture;
#[inline]
fn sleep(&self, duration: Duration) -> Self::SleepFuture {
self.inner.sleep.sleep(duration)
}
#[inline]
fn now(&self) -> Instant {
self.inner.sleep.now()
}
#[inline]
fn wallclock(&self) -> SystemTime {
self.inner.sleep.wallclock()
}
}
/// Coarse time queries are forwarded to the `CoarseTimeR` component.
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> CoarseTimeProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
    // All requirements on `CoarseTimeR` in one predicate instead of two
    // separate lines for the same type parameter.
    CoarseTimeR: CoarseTimeProvider + Clone + Send + Sync + 'static,
    SleepR: Clone + Send + Sync + 'static,
    SpawnR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// Delegate to the coarse-time component's `now_coarse`.
    #[inline]
    fn now_coarse(&self) -> CoarseInstant {
        self.inner.coarse_time.now_coarse()
    }
}
/// TCP connect/listen are forwarded to the `TcpR` component.
#[async_trait]
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> TcpProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
    // All requirements on `TcpR` in one predicate instead of being split
    // across two lines for the same type parameter.
    TcpR: TcpProvider + Send + Sync + 'static,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    CoarseTimeR: Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
    UdpR: Send + Sync + 'static,
{
    type TcpStream = TcpR::TcpStream;
    type TcpListener = TcpR::TcpListener;

    /// Delegate to the TCP component's `connect`; its result (including any
    /// I/O error) is returned unchanged.
    #[inline]
    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
        self.inner.tcp.connect(addr).await
    }

    /// Delegate to the TCP component's `listen`; its result (including any
    /// I/O error) is returned unchanged.
    #[inline]
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
        self.inner.tcp.listen(addr).await
    }
}
/// TLS connector creation is forwarded to the `TlsR` component.
//
// NOTE(review): `TcpR: TcpProvider` is required even though the `tcp`
// component is not used here; presumably this is needed for the runtime as a
// whole to satisfy `TlsProvider`'s constraints. Confirm before relaxing it.
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR, S> TlsProvider<S>
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
    TcpR: TcpProvider,
    TlsR: TlsProvider<S>,
    SleepR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    SpawnR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    type Connector = <TlsR as TlsProvider<S>>::Connector;
    type TlsStream = <TlsR as TlsProvider<S>>::TlsStream;

    /// Delegate to the TLS component's `tls_connector`.
    #[inline]
    fn tls_connector(&self) -> Self::Connector {
        <TlsR as TlsProvider<S>>::tls_connector(&self.inner.tls)
    }
}
/// Opaque `Debug` output. No `Debug` bound is placed on the components, so
/// only the type name is printed.
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> std::fmt::Debug
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CompoundRuntime");
        // No fields are added, so this renders as `CompoundRuntime { .. }`.
        out.finish_non_exhaustive()
    }
}
/// UDP socket binding is forwarded to the `UdpR` component.
#[async_trait]
impl<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR> UdpProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, TlsR, UdpR>
where
    // All requirements on `UdpR` in one predicate instead of being split
    // across two lines for the same type parameter.
    UdpR: UdpProvider + Send + Sync + 'static,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    CoarseTimeR: Send + Sync + 'static,
    TcpR: Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
{
    type UdpSocket = UdpR::UdpSocket;

    /// Delegate to the UDP component's `bind`; its result (including any
    /// I/O error) is returned unchanged.
    #[inline]
    async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
        self.inner.udp.bind(addr).await
    }
}