twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
10use crate::inflater::Inflater;
11use crate::{
12 channel::{MessageChannel, MessageSender},
13 error::{ReceiveMessageError, ReceiveMessageErrorType},
14 json,
15 latency::Latency,
16 queue::{InMemoryQueue, Queue},
17 ratelimiter::CommandRatelimiter,
18 session::Session,
19 Command, Config, Message, ShardId, API_VERSION,
20};
21use futures_core::Stream;
22use futures_sink::Sink;
23use serde::{de::DeserializeOwned, Deserialize};
24#[cfg(any(
25 feature = "native-tls",
26 feature = "rustls-native-roots",
27 feature = "rustls-platform-verifier",
28 feature = "rustls-webpki-roots"
29))]
30use std::io::ErrorKind as IoErrorKind;
31use std::{
32 env::consts::OS,
33 fmt,
34 future::Future,
35 pin::Pin,
36 str,
37 task::{ready, Context, Poll},
38};
39use tokio::{
40 net::TcpStream,
41 sync::oneshot,
42 time::{self, Duration, Instant, Interval, MissedTickBehavior},
43};
44use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
45use twilight_model::gateway::{
46 event::GatewayEventDeserializer,
47 payload::{
48 incoming::Hello,
49 outgoing::{
50 identify::{IdentifyInfo, IdentifyProperties},
51 Heartbeat, Identify, Resume,
52 },
53 },
54 CloseCode, CloseFrame, Intents, OpCode,
55};
56
/// Default Discord gateway endpoint, used when neither a resume URL nor a
/// proxy URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Query-string fragment asking the gateway for zlib-stream compression.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
const COMPRESSION_FEATURES: &str = "&compress=zlib-stream";

/// Empty query-string fragment: compression support is not compiled in.
#[cfg(not(any(feature = "zlib-stock", feature = "zlib-simd")))]
const COMPRESSION_FEATURES: &str = "";
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Wrapper struct around an `async fn` with a `Debug` implementation.
72struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
73
74impl fmt::Debug for ConnectionFuture {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_tuple("ConnectionFuture")
77 .field(&"<async fn>")
78 .finish()
79 }
80}
81
82/// Close initiator of a websocket connection.
83#[derive(Clone, Debug)]
84enum CloseInitiator {
85 /// Gateway initiated the close.
86 ///
87 /// Contains an optional close code.
88 Gateway(Option<u16>),
89 /// Shard initiated the close.
90 ///
91 /// Contains a close code.
92 Shard(CloseFrame<'static>),
93 /// Transport error initiated the close.
94 Transport,
95}
96
/// Lifecycle state of a [Shard].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway and holding an active session.
    Active,
    /// Not connected to the gateway, though a reconnect may still follow.
    ///
    /// The underlying websocket may not have finished closing yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed for good; the shard will not reconnect.
    ///
    /// Typical causes include [failed authentication] and [invalid intents];
    /// [`CloseCode`] documents the full set of reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting for a session to be established or resumed.
    Identifying,
    /// Replaying dispatch events that were missed while disconnected.
    ///
    /// A resuming shard already counts as identified.
    Resuming,
}
125
126impl ShardState {
127 /// Determine the connection status from the close code.
128 ///
129 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
130 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
131 /// the close code is unknown.
132 fn from_close_code(close_code: Option<u16>) -> Self {
133 match close_code.map(CloseCode::try_from) {
134 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
135 _ => Self::Disconnected {
136 reconnect_attempts: 0,
137 },
138 }
139 }
140
141 /// Whether the shard has disconnected but may reconnect in the future.
142 const fn is_disconnected(self) -> bool {
143 matches!(self, Self::Disconnected { .. })
144 }
145
146 /// Whether the shard is identified with an active session.
147 ///
148 /// `true` if the status is [`Active`] or [`Resuming`].
149 ///
150 /// [`Active`]: Self::Active
151 /// [`Resuming`]: Self::Resuming
152 pub const fn is_identified(self) -> bool {
153 matches!(self, Self::Active | Self::Resuming)
154 }
155}
156
157/// Gateway event with only minimal required data.
158#[derive(Deserialize)]
159struct MinimalEvent<T> {
160 /// Attached data of the gateway event.
161 #[serde(rename = "d")]
162 data: T,
163}
164
165/// Minimal [`Ready`] for light deserialization.
166///
167/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
168#[derive(Deserialize)]
169struct MinimalReady {
170 /// Used for resuming connections.
171 resume_gateway_url: Box<str>,
172 /// ID of the new identified session.
173 session_id: String,
174}
175
176/// Pending outgoing message indicator.
177#[derive(Debug)]
178struct Pending {
179 /// The pending message, if not already sent.
180 gateway_event: Option<Message>,
181 /// Whether the pending gateway event is a heartbeat.
182 is_heartbeat: bool,
183}
184
185impl Pending {
186 /// Constructor for a pending gateway event.
187 const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
188 Some(Self {
189 gateway_event: Some(Message::Text(json)),
190 is_heartbeat,
191 })
192 }
193}
194
195/// Gateway API client responsible for up to 2500 guilds.
196///
197/// Shards are responsible for maintaining the gateway connection by processing
198/// events relevant to the operation of shards---such as requests from the
199/// gateway to re-connect or invalidate a session---and then to pass them on to
200/// the user.
201///
202/// Shards start out disconnected, but will on the first successful call to
203/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
204/// be repeatedly called in order for the shard to maintain its connection and
205/// update its internal state.
206///
207/// Shards go through an [identify queue][`queue`] that rate limits concurrent
208/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
209/// invalidates the shard's session and it is therefore **very important** to
210/// reuse the same queue for all shards.
211///
212/// # Sharding
213///
214/// A shard may not be connected to more than 2500 guilds, so large bots must
215/// split themselves across multiple shards. See the
216/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
217/// more info.
218///
219/// # Examples
220///
221/// Create and start a shard and print new and deleted messages:
222///
223/// ```no_run
224/// use std::env;
225/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
226///
227/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
228/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
229/// // token. Of course, this value may be passed into the program however is
230/// // preferred.
231/// let token = env::var("DISCORD_TOKEN")?;
232/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
233///
234/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
235///
236/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
237/// let Ok(event) = item else {
238/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
239///
240/// continue;
241/// };
242///
243/// match event {
244/// Event::MessageCreate(message) => {
245/// println!("message received with content: {}", message.content);
246/// }
247/// Event::MessageDelete(message) => {
248/// println!("message with ID {} deleted", message.id);
249/// }
250/// _ => {}
251/// }
252/// }
253/// # Ok(()) }
254/// ```
255///
256/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
257/// [gateway commands]: Shard::command
258/// [`poll_next`]: Shard::poll_next
259/// [`queue`]: crate::queue
260#[derive(Debug)]
261pub struct Shard<Q = InMemoryQueue> {
262 /// User provided configuration.
263 ///
264 /// Configurations are provided or created in shard initializing via
265 /// [`Shard::new`] or [`Shard::with_config`].
266 config: Config<Q>,
267 /// Future to establish a WebSocket connection with the Gateway.
268 connection_future: Option<ConnectionFuture>,
269 /// Websocket connection, which may be connected to Discord's gateway.
270 ///
271 /// The connection should only be dropped after it has returned `Ok(None)`
272 /// to comply with the WebSocket protocol.
273 connection: Option<Connection>,
274 /// Interval of how often the gateway would like the shard to send
275 /// heartbeats.
276 ///
277 /// The interval is received in the [`GatewayEvent::Hello`] event when
278 /// first opening a new [connection].
279 ///
280 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
281 /// [connection]: Self::connection
282 heartbeat_interval: Option<Interval>,
283 /// Whether an event has been received in the current heartbeat interval.
284 heartbeat_interval_event: bool,
285 /// ID of the shard.
286 id: ShardId,
287 /// Identify queue receiver.
288 identify_rx: Option<oneshot::Receiver<()>>,
289 /// Zlib decompressor.
290 #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
291 inflater: Inflater,
292 /// Potentially pending outgoing message.
293 pending: Option<Pending>,
294 /// Recent heartbeat latency statistics.
295 ///
296 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
297 /// may have changed, invalidating previous latency statistic.
298 ///
299 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
300 latency: Latency,
301 /// Command ratelimiter, if it was enabled via
302 /// [`Config::ratelimit_messages`].
303 ratelimiter: Option<CommandRatelimiter>,
304 /// Used for resuming connections.
305 resume_url: Option<Box<str>>,
306 /// Active session of the shard.
307 ///
308 /// The shard may not have an active session if it hasn't yet identified and
309 /// received a `READY` dispatch event response.
310 session: Option<Session>,
311 /// Current state of the shard.
312 state: ShardState,
313 /// Messages from the user to be relayed and sent over the Websocket
314 /// connection.
315 user_channel: MessageChannel,
316}
317
318impl Shard {
319 /// Create a new shard with the default configuration.
320 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
321 Self::with_config(id, Config::new(token, intents))
322 }
323}
324
325impl<Q> Shard<Q> {
326 /// Create a new shard with the provided configuration.
327 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
328 let session = config.take_session();
329 let mut resume_url = config.take_resume_url();
330 //ensure resume_url is only used if we have a session to resume
331 if session.is_none() {
332 resume_url = None;
333 }
334
335 Self {
336 config,
337 connection_future: None,
338 connection: None,
339 heartbeat_interval: None,
340 heartbeat_interval_event: false,
341 id: shard_id,
342 identify_rx: None,
343 #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
344 inflater: Inflater::new(),
345 pending: None,
346 latency: Latency::new(),
347 ratelimiter: None,
348 resume_url,
349 session,
350 state: ShardState::Disconnected {
351 reconnect_attempts: 0,
352 },
353 user_channel: MessageChannel::new(),
354 }
355 }
356
357 /// Immutable reference to the configuration used to instantiate this shard.
358 pub const fn config(&self) -> &Config<Q> {
359 &self.config
360 }
361
362 /// ID of the shard.
363 pub const fn id(&self) -> ShardId {
364 self.id
365 }
366
367 /// Zlib decompressor statistics.
368 ///
369 /// Reset when reconnecting to the gateway.
370 #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
371 pub const fn inflater(&self) -> &Inflater {
372 &self.inflater
373 }
374
375 /// State of the shard.
376 pub const fn state(&self) -> ShardState {
377 self.state
378 }
379
380 /// Shard latency statistics, including average latency and recent heartbeat
381 /// latency times.
382 ///
383 /// Reset when reconnecting to the gateway.
384 pub const fn latency(&self) -> &Latency {
385 &self.latency
386 }
387
388 /// Statistics about the number of available commands and when the command
389 /// ratelimiter will refresh.
390 ///
391 /// This won't be present if ratelimiting was disabled via
392 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
393 ///
394 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
395 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
396 self.ratelimiter.as_ref()
397 }
398
399 /// Immutable reference to the gateways current resume URL.
400 ///
401 /// A resume URL might not be present if the shard had its session
402 /// invalidated and has not yet reconnected.
403 pub fn resume_url(&self) -> Option<&str> {
404 self.resume_url.as_deref()
405 }
406
407 /// Immutable reference to the active gateway session.
408 ///
409 /// An active session may not be present if the shard had its session
410 /// invalidated and has not yet reconnected.
411 pub const fn session(&self) -> Option<&Session> {
412 self.session.as_ref()
413 }
414
415 /// Queue a command to be sent to the gateway.
416 ///
417 /// Serializes the command and then calls [`send`].
418 ///
419 /// [`send`]: Self::send
420 #[allow(clippy::missing_panics_doc)]
421 pub fn command(&self, command: &impl Command) {
422 self.send(json::to_string(command).expect("serialization cannot fail"));
423 }
424
425 /// Queue a JSON encoded gateway event to be sent to the gateway.
426 #[allow(clippy::missing_panics_doc)]
427 pub fn send(&self, json: String) {
428 self.user_channel
429 .command_tx
430 .send(json)
431 .expect("channel open");
432 }
433
434 /// Queue a websocket close frame.
435 ///
436 /// Invalidates the session and shows the application's bot as offline if
437 /// the close frame code is `1000` or `1001`. Otherwise Discord will
438 /// continue showing the bot as online until its presence times out.
439 ///
440 /// To read all remaining messages, continue calling [`poll_next`] until it
441 /// returns [`Message::Close`].
442 ///
443 /// # Example
444 ///
445 /// Close the shard and process remaining messages:
446 ///
447 /// ```no_run
448 /// # use twilight_gateway::{Intents, Shard, ShardId};
449 /// # #[tokio::main] async fn main() {
450 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
451 /// use tokio_stream::StreamExt;
452 /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
453 ///
454 /// shard.close(CloseFrame::NORMAL);
455 ///
456 /// while let Some(item) = shard.next().await {
457 /// match item {
458 /// Ok(Message::Close(_)) => break,
459 /// Ok(Message::Text(_)) => unimplemented!(),
460 /// Err(source) => unimplemented!(),
461 /// }
462 /// }
463 /// # }
464 /// ```
465 ///
466 /// [`poll_next`]: Shard::poll_next
467 pub fn close(&self, close_frame: CloseFrame<'static>) {
468 _ = self.user_channel.close_tx.try_send(close_frame);
469 }
470
471 /// Retrieve a channel to send messages over the shard to the gateway.
472 ///
473 /// This is primarily useful for sending to other tasks and threads where
474 /// the shard won't be available.
475 ///
476 /// # Example
477 ///
478 /// Queue a command in another process:
479 ///
480 /// ```no_run
481 /// # use twilight_gateway::{Intents, Shard, ShardId};
482 /// # #[tokio::main] async fn main() {
483 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
484 /// use tokio_stream::StreamExt;
485 ///
486 /// while let Some(item) = shard.next().await {
487 /// match item {
488 /// Ok(message) => {
489 /// let sender = shard.sender();
490 /// tokio::spawn(async move {
491 /// let command = unimplemented!();
492 /// sender.send(command);
493 /// });
494 /// }
495 /// Err(source) => unimplemented!(),
496 /// }
497 /// }
498 /// # }
499 /// ```
500 pub fn sender(&self) -> MessageSender {
501 self.user_channel.sender()
502 }
503
504 /// Update internal state from gateway disconnect.
505 fn disconnect(&mut self, initiator: CloseInitiator) {
506 // May not send any additional WebSocket messages.
507 self.heartbeat_interval = None;
508 self.ratelimiter = None;
509 // Abort identify.
510 self.identify_rx = None;
511 self.state = match initiator {
512 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
513 _ => ShardState::Disconnected {
514 reconnect_attempts: 0,
515 },
516 };
517 if let CloseInitiator::Shard(frame) = initiator {
518 // Not resuming, drop session and resume URL.
519 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
520 if matches!(frame.code, 1000 | 1001) {
521 self.resume_url = None;
522 self.session = None;
523 }
524 self.pending = Some(Pending {
525 gateway_event: Some(Message::Close(Some(frame))),
526 is_heartbeat: false,
527 });
528 }
529 }
530
531 /// Parse a JSON message into an event with minimal data for [processing].
532 ///
533 /// # Errors
534 ///
535 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
536 /// event isn't a recognized structure, which may be the case for new or
537 /// undocumented events.
538 ///
539 /// [processing]: Self::process
540 fn parse_event<T: DeserializeOwned>(
541 json: &str,
542 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
543 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
544 kind: ReceiveMessageErrorType::Deserializing {
545 event: json.to_owned(),
546 },
547 source: Some(Box::new(source)),
548 })
549 }
550}
551
552impl<Q: Queue> Shard<Q> {
553 /// Attempts to send due commands to the gateway.
554 ///
555 /// # Returns
556 ///
557 /// * `Poll::Pending` if sending is in progress
558 /// * `Poll::Ready(Ok)` if no more scheduled commands remain
559 /// * `Poll::Ready(Err)` if sending a command failed.
560 fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
561 loop {
562 if let Some(pending) = self.pending.as_mut() {
563 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
564
565 if let Some(message) = &pending.gateway_event {
566 if let Some(ratelimiter) = self.ratelimiter.as_mut() {
567 if message.is_text() && !pending.is_heartbeat {
568 ready!(ratelimiter.poll_acquire(cx));
569 }
570 }
571
572 let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
573 Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
574 }
575
576 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
577
578 if pending.is_heartbeat {
579 self.latency.record_sent();
580 }
581 self.pending = None;
582 }
583
584 if !self.state.is_disconnected() {
585 if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
586 let frame = frame.expect("shard owns channel");
587
588 tracing::debug!("sending close frame from user channel");
589 self.disconnect(CloseInitiator::Shard(frame));
590
591 continue;
592 }
593 }
594
595 if self
596 .heartbeat_interval
597 .as_mut()
598 .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
599 {
600 // Discord never responded after the last heartbeat, connection
601 // is failed or "zombied", see
602 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
603 // Note that unlike documented *any* event is okay; it does not
604 // have to be a heartbeat ACK.
605 if self.latency.sent().is_some() && !self.heartbeat_interval_event {
606 tracing::info!("connection is failed or \"zombied\"");
607 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
608 } else {
609 tracing::debug!("sending heartbeat");
610 self.pending = Pending::text(
611 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
612 .expect("serialization cannot fail"),
613 true,
614 );
615 self.heartbeat_interval_event = false;
616 }
617
618 continue;
619 }
620
621 let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
622 ratelimiter.poll_available(cx).is_ready()
623 });
624
625 if not_ratelimited {
626 if let Some(Poll::Ready(canceled)) = self
627 .identify_rx
628 .as_mut()
629 .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
630 {
631 if canceled {
632 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
633 continue;
634 }
635
636 tracing::debug!("sending identify");
637
638 self.pending = Pending::text(
639 json::to_string(&Identify::new(IdentifyInfo {
640 compress: false,
641 intents: self.config.intents(),
642 large_threshold: self.config.large_threshold(),
643 presence: self.config.presence().cloned(),
644 properties: self
645 .config
646 .identify_properties()
647 .cloned()
648 .unwrap_or_else(default_identify_properties),
649 shard: Some(self.id),
650 token: self.config.token().to_owned(),
651 }))
652 .expect("serialization cannot fail"),
653 false,
654 );
655 self.identify_rx = None;
656
657 continue;
658 }
659 }
660
661 if not_ratelimited && self.state.is_identified() {
662 if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
663 let command = command.expect("shard owns channel");
664
665 tracing::debug!("sending command from user channel");
666 self.pending = Some(Pending {
667 gateway_event: Some(Message::Text(command)),
668 is_heartbeat: false,
669 });
670
671 continue;
672 }
673 }
674
675 return Poll::Ready(Ok(()));
676 }
677 }
678
679 /// Updates the shard's internal state from a gateway event by recording
680 /// and/or responding to certain Discord events.
681 ///
682 /// # Errors
683 ///
684 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
685 /// gateway event isn't a recognized structure.
686 #[allow(clippy::too_many_lines)]
687 fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
688 let (raw_opcode, maybe_sequence, maybe_event_type) =
689 GatewayEventDeserializer::from_json(event)
690 .ok_or_else(|| ReceiveMessageError {
691 kind: ReceiveMessageErrorType::Deserializing {
692 event: event.to_owned(),
693 },
694 source: Some("missing opcode".into()),
695 })?
696 .into_parts();
697
698 if self.latency.sent().is_some() {
699 self.heartbeat_interval_event = true;
700 }
701
702 match OpCode::from(raw_opcode) {
703 Some(OpCode::Dispatch) => {
704 let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
705 kind: ReceiveMessageErrorType::Deserializing {
706 event: event.to_owned(),
707 },
708 source: Some("missing dispatch event type".into()),
709 })?;
710 let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
711 kind: ReceiveMessageErrorType::Deserializing {
712 event: event.to_owned(),
713 },
714 source: Some("missing sequence".into()),
715 })?;
716 tracing::debug!(%event_type, %sequence, "received dispatch");
717
718 match event_type.as_ref() {
719 "READY" => {
720 let event = Self::parse_event::<MinimalReady>(event)?;
721
722 self.resume_url = Some(event.data.resume_gateway_url);
723 self.session = Some(Session::new(sequence, event.data.session_id));
724 self.state = ShardState::Active;
725 }
726 "RESUMED" => self.state = ShardState::Active,
727 _ => {}
728 }
729
730 if let Some(session) = self.session.as_mut() {
731 session.set_sequence(sequence);
732 }
733 }
734 Some(OpCode::Heartbeat) => {
735 tracing::debug!("received heartbeat");
736 self.pending = Pending::text(
737 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
738 .expect("serialization cannot fail"),
739 true,
740 );
741 }
742 Some(OpCode::HeartbeatAck) => {
743 let requested = self.latency.received().is_none() && self.latency.sent().is_some();
744 if requested {
745 tracing::debug!("received heartbeat ack");
746 self.latency.record_received();
747 } else {
748 tracing::info!("received unrequested heartbeat ack");
749 }
750 }
751 Some(OpCode::Hello) => {
752 let event = Self::parse_event::<Hello>(event)?;
753 let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
754 // First heartbeat should have some jitter, see
755 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
756 let jitter = heartbeat_interval.mul_f64(fastrand::f64());
757 tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
758
759 if self.config().ratelimit_messages() {
760 self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
761 }
762
763 let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
764 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
765 self.heartbeat_interval = Some(interval);
766
767 // Reset `Latency` since the shard might have connected to a new
768 // remote which invalidates the recorded latencies.
769 self.latency = Latency::new();
770
771 if let Some(session) = &self.session {
772 self.pending = Pending::text(
773 json::to_string(&Resume::new(
774 session.sequence(),
775 session.id(),
776 self.config.token(),
777 ))
778 .expect("serialization cannot fail"),
779 false,
780 );
781 self.state = ShardState::Resuming;
782 } else {
783 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
784 }
785 }
786 Some(OpCode::InvalidSession) => {
787 let resumable = Self::parse_event(event)?.data;
788 tracing::debug!(resumable, "received invalid session");
789 if resumable {
790 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
791 } else {
792 self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
793 }
794 }
795 Some(OpCode::Reconnect) => {
796 tracing::debug!("received reconnect");
797 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
798 }
799 _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
800 }
801
802 Ok(())
803 }
804}
805
806impl<Q: Queue + Unpin> Stream for Shard<Q> {
807 type Item = Result<Message, ReceiveMessageError>;
808
809 #[tracing::instrument(fields(id = %self.id), name = "shard", skip_all)]
810 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
811 let message = loop {
812 match self.state {
813 ShardState::FatallyClosed => {
814 _ = ready!(Pin::new(
815 self.connection
816 .as_mut()
817 .expect("poll_next called after Poll::Ready(None)")
818 )
819 .poll_close(cx));
820 self.connection = None;
821 return Poll::Ready(None);
822 }
823 ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
824 if self.connection_future.is_none() {
825 let base_url = self
826 .resume_url
827 .as_deref()
828 .or_else(|| self.config.proxy_url())
829 .unwrap_or(GATEWAY_URL);
830 let uri = format!(
831 "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
832 );
833
834 tracing::debug!(url = base_url, "connecting to gateway");
835
836 let tls = self.config.tls.clone();
837 self.connection_future = Some(ConnectionFuture(Box::pin(async move {
838 let secs = 2u8.saturating_pow(reconnect_attempts.into());
839 time::sleep(Duration::from_secs(secs.into())).await;
840
841 Ok(ClientBuilder::new()
842 .uri(&uri)
843 .expect("URL should be valid")
844 .limits(Limits::unlimited())
845 .connector(&tls)
846 .connect()
847 .await?
848 .0)
849 })));
850 }
851
852 let res =
853 ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
854 self.connection_future = None;
855 match res {
856 Ok(connection) => {
857 self.connection = Some(connection);
858 self.state = ShardState::Identifying;
859 #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
860 self.inflater.reset();
861 }
862 Err(source) => {
863 self.resume_url = None;
864 self.state = ShardState::Disconnected {
865 reconnect_attempts: reconnect_attempts + 1,
866 };
867
868 return Poll::Ready(Some(Err(ReceiveMessageError {
869 kind: ReceiveMessageErrorType::Reconnect,
870 source: Some(Box::new(source)),
871 })));
872 }
873 }
874 }
875 _ => {}
876 }
877
878 if ready!(self.poll_send(cx)).is_err() {
879 self.disconnect(CloseInitiator::Transport);
880 self.connection = None;
881
882 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
883 }
884
885 match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
886 Some(Ok(message)) => {
887 #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
888 if message.is_binary() {
889 if let Some(decompressed) = self
890 .inflater
891 .inflate(message.as_payload())
892 .map_err(ReceiveMessageError::from_compression)?
893 {
894 break Message::Text(decompressed);
895 };
896 }
897 if let Some(message) = Message::from_websocket_msg(&message) {
898 break message;
899 }
900 }
901 // Discord, against recommendations from the WebSocket spec,
902 // does not send a close_notify prior to shutting down the TCP
903 // stream. This arm tries to gracefully handle this. The
904 // connection is considered unusable after encountering an io
905 // error, returning `None`.
906 #[cfg(any(
907 feature = "native-tls",
908 feature = "rustls-native-roots",
909 feature = "rustls-platform-verifier",
910 feature = "rustls-webpki-roots"
911 ))]
912 Some(Err(WebsocketError::Io(e)))
913 if e.kind() == IoErrorKind::UnexpectedEof
914 && self.config.proxy_url().is_none()
915 && self.state.is_disconnected() =>
916 {
917 continue
918 }
919 Some(Err(_)) => {
920 self.disconnect(CloseInitiator::Transport);
921 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
922 }
923 None => {
924 _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
925 tracing::debug!("gateway WebSocket connection closed");
926 // Unclean closure.
927 if !self.state.is_disconnected() {
928 self.disconnect(CloseInitiator::Transport);
929 }
930 self.connection = None;
931 }
932 }
933 };
934
935 match &message {
936 Message::Close(frame) => {
937 // tokio-websockets automatically replies to the close message.
938 tracing::debug!(?frame, "received WebSocket close message");
939 // Don't run `disconnect` if we initiated the close.
940 if !self.state.is_disconnected() {
941 self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
942 }
943 }
944 Message::Text(event) => {
945 self.process(event)?;
946 }
947 }
948
949 Poll::Ready(Some(Ok(message)))
950 }
951}
952
953/// Default identify properties to use when the user hasn't customized it in
954/// [`Config::identify_properties`].
955///
956/// [`Config::identify_properties`]: Config::identify_properties
957fn default_identify_properties() -> IdentifyProperties {
958 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
959}
960
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // A shard can be debug-printed and moved across threads, but must not be
    // shared between them.
    assert_impl_all!(Shard: Send, Debug);
    assert_not_impl_any!(Shard: Sync);
}