1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
//! Define a [`CompoundRuntime`] type that can be built from several component
//! pieces.

use std::{net::SocketAddr, sync::Arc, time::Duration};

use crate::traits::*;
use async_trait::async_trait;
use futures::{future::FutureObj, task::Spawn};
use std::io::Result as IoResult;

/// A runtime made of several parts, each of which implements one trait-group.
///
/// The `SpawnR` component should implement [`Spawn`] and [`BlockOn`];
/// the `SleepR` component should implement [`SleepProvider`]; the `TcpR`
/// component should implement [`TcpProvider`]; and the `TlsR` component should
/// implement [`TlsProvider`].
///
/// Each of those traits is implemented for `CompoundRuntime` by forwarding the
/// call to the matching component.
///
/// You can use this structure to create new runtimes in two ways: either by
/// overriding a single part of an existing runtime, or by building an entirely
/// new runtime from pieces.
pub struct CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
    /// The actual collection of Runtime objects.
    ///
    /// We wrap this in an Arc rather than requiring that each item implement
    /// Clone, though we could change our minds later on.
    inner: Arc<Inner<SpawnR, SleepR, TcpR, TlsR>>,
}

// We have to provide this ourselves: `#[derive(Clone)]` would add a `Clone`
// bound on every type parameter (`SpawnR: Clone`, `SleepR: Clone`, ...), even
// though cloning only needs to bump the reference count on the shared `Arc`.
impl<SpawnR, SleepR, TcpR, TlsR> Clone for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
    /// Return a new handle that shares the same underlying components.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// A single `Inner` is shared (behind an `Arc`) by every clone of a
/// [`CompoundRuntime`].
struct Inner<SpawnR, SleepR, TcpR, TlsR> {
    /// A `Spawn` and `BlockOn` implementation.
    spawn: SpawnR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `TcpProvider` implementation.
    tcp: TcpR,
    /// A `TlsProvider<TcpR::TcpStream>` implementation.
    tls: TlsR,
}

impl<SpawnR, SleepR, TcpR, TlsR> CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
    /// Assemble a `CompoundRuntime` out of one component per trait-group.
    ///
    /// `spawn` is expected to provide spawning and blocking, `sleep` timers,
    /// `tcp` TCP networking, and `tls` TLS connectors. All four are moved into
    /// a single shared allocation.
    pub fn new(spawn: SpawnR, sleep: SleepR, tcp: TcpR, tls: TlsR) -> Self {
        let parts = Inner { spawn, sleep, tcp, tls };
        Self { inner: Arc::new(parts) }
    }
}

// Spawning is handed off to the `SpawnR` component.
impl<SpawnR: Spawn, SleepR, TcpR, TlsR> Spawn for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
    #[inline]
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
        let spawner = &self.inner.spawn;
        spawner.spawn_obj(future)
    }
}

// Blocking on a future is also the job of the `SpawnR` component.
impl<SpawnR, SleepR, TcpR, TlsR> BlockOn for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
where
    SpawnR: BlockOn,
{
    #[inline]
    fn block_on<Fut: futures::Future>(&self, future: Fut) -> Fut::Output {
        let executor = &self.inner.spawn;
        executor.block_on(future)
    }
}

// Timers come from the `SleepR` component; its future type is reused as-is.
impl<SpawnR, SleepR: SleepProvider, TcpR, TlsR> SleepProvider
    for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
{
    type SleepFuture = <SleepR as SleepProvider>::SleepFuture;

    #[inline]
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
        let sleeper = &self.inner.sleep;
        sleeper.sleep(duration)
    }
}

// TCP connections and listeners come from the `TcpR` component. The
// `Send + Sync + 'static` bounds on every component are presumably required
// because the boxed futures generated by `#[async_trait]` capture `&self`
// and must be `Send`.
#[async_trait]
impl<SpawnR, SleepR, TcpR, TlsR> TcpProvider for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
where
    TcpR: TcpProvider + Send + Sync + 'static,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
{
    type TcpStream = TcpR::TcpStream;
    type TcpListener = TcpR::TcpListener;

    #[inline]
    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
        let provider = &self.inner.tcp;
        provider.connect(addr).await
    }

    #[inline]
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
        let provider = &self.inner.tcp;
        provider.listen(addr).await
    }
}

// TLS connectors come from the `TlsR` component, for whatever stream type `S`
// that component supports.
impl<SpawnR, SleepR, TcpR, TlsR, S> TlsProvider<S> for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR>
where
    TlsR: TlsProvider<S>,
    TcpR: TcpProvider,
{
    type TlsStream = TlsR::TlsStream;
    type Connector = TlsR::Connector;

    #[inline]
    fn tls_connector(&self) -> Self::Connector {
        let tls = &self.inner.tls;
        tls.tls_connector()
    }
}

impl<SpawnR, SleepR, TcpR, TlsR> std::fmt::Debug for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR> {
    /// The components are not required to implement `Debug`, so none of them
    /// is printed; the output is just the type name with an elision marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompoundRuntime");
        builder.finish_non_exhaustive()
    }
}