Skip to main content

rustrade_data/streams/builder/
mod.rs

1use super::Streams;
2use crate::{
3    Identifier,
4    error::DataError,
5    exchange::StreamSelector,
6    instrument::InstrumentData,
7    streams::{
8        consumer::{MarketStreamResult, STREAM_RECONNECTION_POLICY, init_market_stream},
9        reconnect::stream::ReconnectingStream,
10    },
11    subscription::{Subscription, SubscriptionKind},
12};
13use rustrade_instrument::exchange::ExchangeId;
14use rustrade_integration::{Validator, channel::Channel};
15use std::{
16    collections::HashMap,
17    fmt::{Debug, Display},
18    future::Future,
19    pin::Pin,
20};
21
22/// Defines the [`MultiStreamBuilder`](multi::MultiStreamBuilder) API for ergonomically
23/// initialising a common [`Streams<Output>`](Streams) from multiple
24/// [`StreamBuilder<SubscriptionKind>`](StreamBuilder)s.
25pub mod multi;
26
27/// Defines the [`DynamicStreams`](dynamic::DynamicStreams) API for initialising an arbitrary number
28/// of `MarketStream`s from the [`ExchangeId`] and [`SubKind`](crate::subscription::SubKind) enums, rather than concrete
29/// types.
30pub mod dynamic;
31
32/// Communicative type alias representing the [`Future`] result of a [`Subscription`] validation
33/// call generated whilst executing [`StreamBuilder::subscribe`].
34pub type SubscribeFuture = Pin<Box<dyn Future<Output = Result<(), DataError>>>>;
35
36/// Builder to configure and initialise a [`Streams<MarketEvent<SubscriptionKind::Event>`](Streams) instance
37/// for a specific [`SubscriptionKind`].
38#[derive(Default)]
39pub struct StreamBuilder<InstrumentKey, Kind>
40where
41    Kind: SubscriptionKind,
42{
43    pub channels: HashMap<ExchangeId, Channel<MarketStreamResult<InstrumentKey, Kind::Event>>>,
44    pub futures: Vec<SubscribeFuture>,
45}
46
47impl<InstrumentKey, Kind> Debug for StreamBuilder<InstrumentKey, Kind>
48where
49    InstrumentKey: Debug,
50    Kind: SubscriptionKind,
51{
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("StreamBuilder<InstrumentKey, SubscriptionKind>")
54            .field("channels", &self.channels)
55            .field("num_futures", &self.futures.len())
56            .finish()
57    }
58}
59
60impl<InstrumentKey, Kind> StreamBuilder<InstrumentKey, Kind>
61where
62    Kind: SubscriptionKind,
63{
64    /// Construct a new [`Self`].
65    pub fn new() -> Self {
66        Self {
67            channels: HashMap::new(),
68            futures: Vec::new(),
69        }
70    }
71
72    /// Add a collection of [`Subscription`]s to the [`StreamBuilder`] that will be actioned on
73    /// a distinct [`WebSocket`](rustrade_integration::protocol::websocket::WebSocket) connection.
74    ///
75    /// The `subscriber` handles the WebSocket connection and authentication.
76    /// For unauthenticated exchanges, use [`WebSocketSubscriber`](crate::subscriber::WebSocketSubscriber).
77    /// For authenticated exchanges like Alpaca, use the exchange-specific subscriber with credentials.
78    ///
79    /// Note that [`Subscription`]s are not actioned until the
80    /// [`init()`](StreamBuilder::init()) method is invoked.
81    pub fn subscribe<SubIter, Sub, Exchange, Instrument>(
82        mut self,
83        subscriber: Exchange::Subscriber,
84        subscriptions: SubIter,
85    ) -> Self
86    where
87        SubIter: IntoIterator<Item = Sub>,
88        Sub: Into<Subscription<Exchange, Instrument, Kind>>,
89        Exchange: StreamSelector<Instrument, Kind> + Ord + Send + Sync + 'static,
90        Instrument: InstrumentData<Key = InstrumentKey> + Ord + Display + 'static,
91        Instrument::Key: Debug + Clone + Send + 'static,
92        Kind: Ord + Display + Send + Sync + 'static,
93        Kind::Event: Clone + Send,
94        Subscription<Exchange, Instrument, Kind>:
95            Identifier<Exchange::Channel> + Identifier<Exchange::Market>,
96    {
97        // Construct Vec<Subscriptions> from input SubIter
98        let subscriptions = subscriptions.into_iter().map(Sub::into).collect::<Vec<_>>();
99
100        // Acquire channel Sender to send Market<Kind::Event> from consumer loop to user
101        // '--> Add ExchangeChannel Entry if this Exchange <--> SubscriptionKind combination is new
102        let exchange_tx = self.channels.entry(Exchange::ID).or_default().tx.clone();
103
104        // Add Future that once awaited will yield the Result<(), SocketError> of subscribing
105        self.futures.push(Box::pin(async move {
106            // Validate Subscriptions
107            let mut subscriptions = subscriptions
108                .into_iter()
109                .map(Subscription::validate)
110                .collect::<Result<Vec<_>, _>>()?;
111
112            // Remove duplicate Subscriptions
113            subscriptions.sort();
114            subscriptions.dedup();
115
116            // Initialise a MarketEvent `ReconnectingStream`
117            let stream =
118                init_market_stream(STREAM_RECONNECTION_POLICY, subscriber, subscriptions).await?;
119
120            // Forward MarketEvents to ExchangeTx
121            tokio::spawn(stream.forward_to(exchange_tx));
122
123            Ok(())
124        }));
125
126        self
127    }
128
129    /// Spawn a [`MarketStreamResult<SubscriptionKind::Event>`](MarketStreamResult) consumer loop
130    /// for each collection of [`Subscription`]s added to [`StreamBuilder`] via the
131    /// [`subscribe()`](StreamBuilder::subscribe()) method.
132    ///
133    /// Each consumer loop distributes consumed [`MarketStreamResult`] to
134    /// the [`Streams`] `HashMap` returned by this method.
135    pub async fn init(
136        self,
137    ) -> Result<Streams<MarketStreamResult<InstrumentKey, Kind::Event>>, DataError> {
138        // Await Stream initialisation perpetual and ensure success
139        futures::future::try_join_all(self.futures).await?;
140
141        // Construct Streams using each ExchangeChannel receiver
142        Ok(Streams {
143            streams: self
144                .channels
145                .into_iter()
146                .map(|(exchange, channel)| (exchange, channel.rx))
147                .collect(),
148        })
149    }
150}