spectacles_gateway/
manager.rs

1use std::{
2    sync::Arc,
3    time::Duration,
4    time::Instant
5};
6
7use futures::{
8    future::Future,
9    Poll,
10    Stream,
11    sync::mpsc::{unbounded, UnboundedReceiver}
12};
13use futures::future::Loop;
14use futures::sync::mpsc::UnboundedSender;
15use hashbrown::HashMap;
16use parking_lot::{Mutex, RwLock};
17use tokio::timer::Delay;
18use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
19
20use spectacles_model::gateway::{GatewayBot, Opcodes, ReceivePacket};
21
22use crate::{
23    constants::API_BASE,
24    errors::*,
25    queue::{MessageSink, MessageSinkError},
26    shard::{Shard, ShardAction}
27};
28
/// The strategy used to decide how many shards to spawn.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShardStrategy {
    /// Spawn the number of shards Discord recommends in its `/gateway/bot` response.
    Recommended,
    /// Spawn exactly the given number of shards, with shard IDs counting up from 0.
    SpawnAmount(usize),
}
37
38#[derive(Clone)]
39/// Information about a Discord Gateway event received for a shard.
40pub struct ShardEvent {
41    /// The shard which emitted this event.
42    pub shard: ManagerShard,
43    /// The Discord Gateway packet that the event contains.
44    pub packet: ReceivePacket,
45}
46
47/// A collection of shards, keyed by shard ID.
48pub type ShardMap = HashMap<usize, Arc<Mutex<Shard>>>;
49/// An alias for a shard spawned with the sharding manager.
50pub type ManagerShard = Arc<Mutex<Shard>>;
51type MessageStream = UnboundedReceiver<(ManagerShard, TungsteniteMessage)>;
52
53/// A stream of shards being spawned and emitting the ready event.
54pub struct Spawner {
55    inner: UnboundedReceiver<ManagerShard>
56}
57
58impl Spawner {
59    fn new(receiver: UnboundedReceiver<ManagerShard>) -> Self {
60        Spawner { inner: receiver }
61    }
62}
63
64impl Stream for Spawner {
65    type Item = ManagerShard;
66    type Error = ();
67
68    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
69        self.inner.poll()
70    }
71}
72
73/// A stream of incoming Discord events for a shard.
74pub struct EventHandler {
75    inner: UnboundedReceiver<ShardEvent>
76}
77
78impl EventHandler {
79    fn new(receiver: UnboundedReceiver<ShardEvent>) -> Self {
80        EventHandler { inner: receiver }
81    }
82}
83
84impl Stream for EventHandler {
85    type Item = ShardEvent;
86    type Error = ();
87
88    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
89        self.inner.poll()
90    }
91}
92
93#[derive(Clone)]
94struct SpawnerLoop {
95    shardmap: Arc<RwLock<ShardMap>>,
96    sink_tx: UnboundedSender<(ManagerShard, TungsteniteMessage)>,
97    current: usize,
98    total: usize,
99    sender: UnboundedSender<ManagerShard>,
100    token: String,
101    ws: String,
102}
103
104/// The central hub for all shards, where shards are spawned and maintained.
105pub struct ShardManager {
106    /// The token used by this manager to spawn shards.
107    pub token: String,
108    /// The total amount of shards that this manager will attempt to spawn.
109    pub total_shards: usize,
110    /// A collection of shards that have been spawned.
111    pub shards: Arc<RwLock<ShardMap>>,
112    event_sender: Option<UnboundedSender<ShardEvent>>,
113    message_stream: Option<MessageStream>,
114    ws_uri: String
115}
116
117impl ShardManager {
118    /// Creates a new cluster, with the provided Discord API token.
119    pub fn new(token: String, strategy: ShardStrategy) -> impl Future<Item=ShardManager, Error=Error> {
120        let token = if token.starts_with("Bot ") {
121            token
122        } else {
123            format!("Bot {}", token)
124        };
125
126        use reqwest::r#async::Client;
127        Client::new().get(format!("{}/gateway/bot", API_BASE).as_str())
128            .header("Authorization", token.clone()).send()
129            .and_then(|mut resp| resp.json::<GatewayBot>())
130            .map_err(Error::from)
131            .map(move |gb| {
132                let shard_count = match strategy {
133                    ShardStrategy::Recommended => gb.shards,
134                    ShardStrategy::SpawnAmount(int) => int
135                };
136
137                Self {
138                    token,
139                    total_shards: shard_count,
140                    shards: Arc::new(RwLock::new(HashMap::new())),
141                    event_sender: None,
142                    message_stream: None,
143                    ws_uri: gb.url
144                }
145            })
146    }
147
148    /// Spawns shards up to the specified amount and identifies them with Discord.
149    pub fn start_spawn(&mut self) -> (Spawner, EventHandler) {
150        let (sender, receiver) = unbounded();
151        self.message_stream = Some(receiver);
152        let (tx, rx) = unbounded();
153        debug!("Attempting to spawn {} shards.", &self.total_shards);
154        let initial = SpawnerLoop {
155            current: 0,
156            shardmap: Arc::clone(&self.shards),
157            sink_tx: sender.clone(),
158            sender: tx,
159            total: self.total_shards,
160            token: self.token.clone(),
161            ws: self.ws_uri.clone(),
162        };
163
164        tokio::spawn(futures::future::loop_fn(initial, move |state| {
165            Delay::new(Instant::now() + Duration::from_secs(6)).from_err()
166                .map(|_| state)
167                .and_then(move |mut state| {
168                    Shard::new(state.token.clone(), [state.current, state.total], state.ws.clone())
169                        .map(move |shard| {
170                            let wrapped = ManagerShard::new(Mutex::new(shard));
171                            state.shardmap.write().insert(wrapped.lock().info[0], Arc::clone(&wrapped));
172                            let sink = MessageSink {
173                                shard: Arc::clone(&wrapped),
174                                sender: state.sink_tx.clone(),
175                            };
176                            let split = wrapped.lock().stream.lock().take().unwrap().map_err(MessageSinkError::from);
177                            tokio::spawn(split.forward(sink)
178                                .map(|_| ())
179                                .map_err(|e| error!("Failed to forward shard messages to the sink. {:?}", e))
180                            );
181                            state.sender.unbounded_send(wrapped).expect("Failed to send shard to stream");
182                            state.current += 1;
183
184                            state
185                        })
186                })
187                .map(|state| {
188                    if state.current == state.total {
189                        Loop::Break(())
190                    } else {
191                        Loop::Continue(state)
192                    }
193                })
194        }).map_err(|err| {
195            error!("Failed in sharding process. {:?}", err);
196        }));
197
198        (Spawner::new(rx), self.start_event_stream())
199    }
200
201    fn start_event_stream(&mut self) -> EventHandler {
202        let stream = self.message_stream.take().unwrap();
203        let (sender, receiver) = unbounded();
204        self.event_sender = Some(sender.clone());
205
206        tokio::spawn(stream.for_each(move |(shard, message)| {
207            trace!("Websocket message received: {:?}", &message);
208            let event = shard.lock().resolve_packet(&message).expect("Failed to parse the shard message");
209            if let Opcodes::Dispatch = event.op {
210                sender.unbounded_send(ShardEvent {
211                    packet: event.clone(),
212                    shard: Arc::clone(&shard),
213                }).expect("Failed to send shard event to stream");
214            };
215            let action = shard.lock().fulfill_gateway(event.clone()).expect("Failed to fufill gateway message");
216
217            match action {
218                ShardAction::Autoreconnect => {
219                    let sd = Arc::clone(&shard);
220                    tokio::spawn(shard.lock().autoreconnect().map(move |_| {
221                        info!("[Shard {}] Auto reconnection successful.", sd.lock().info[0]);
222                    }).map_err(|err| {
223                        error!("Failed to auto reconnect shard. {}", err);
224                    }));
225                },
226                ShardAction::Identify => {
227                    let info = shard.lock().info;
228                    debug!("[Shard {}] Identifying with the gateway.", &info[0]);
229                    if let Err(e) = shard.lock().identify() {
230                        warn!("[Shard {}] Failed to identify with gateway. {:?}", &info[0], e);
231                    };
232                },
233                ShardAction::Reconnect => {
234                    info!("[Shard {}] Reconnection successful.", shard.lock().info[0]);
235                },
236                ShardAction::Resume => {
237                    let sd = Arc::clone(&shard);
238                    tokio::spawn(shard.lock().resume().map(move |_| {
239                        debug!("[Shard {}] Successfully resumed session.", sd.lock().info[0]);
240                    }).map_err(|err| {
241                        error!("Shard failed to resume session. {}", err);
242                    }));
243                },
244                _ => {}
245            };
246
247            Ok(())
248        }));
249
250        EventHandler::new(receiver)
251    }
252}