1use crate::compression::ZlibDecompressor;
7use crate::error::{CloseCode, GatewayError};
8use crate::event::{parse_event, Event, ReadyEventData};
9use crate::heartbeat::HeartbeatHandler;
10use crate::opcode::OpCode;
11use crate::payload::{
12 create_heartbeat_payload, GatewayPayload, HelloPayload, IdentifyPayload, RawGatewayPayload,
13 ResumePayload,
14};
15use crate::ratelimit::{exponential_backoff, with_jitter, IdentifyRateLimiter};
16use crate::{DEFAULT_GATEWAY_URL, GATEWAY_VERSION};
17
18use flume::Sender;
19use futures_util::{SinkExt, StreamExt};
20use parking_lot::RwLock;
21#[cfg(feature = "simd")]
22use simd_json::prelude::*;
23use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use tokio::net::TcpStream;
27use tokio::time::{sleep, timeout};
28use tokio_tungstenite::tungstenite::protocol::CloseFrame;
29use tokio_tungstenite::tungstenite::Message as WsMessage;
30use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
31use tracing::{debug, error, info, trace, warn};
32use url::Url;
33
/// Commands handed to the shard's active connection loop through its internal channel.
#[derive(Debug)]
enum ShardCommand {
    /// Send this already-serialized JSON text as a WebSocket text frame.
    Send(String),
}
40
/// What the connection loop should do after a gateway frame has been parsed.
enum GatewayAction {
    /// A parsed dispatch event to forward to the event channel.
    Dispatch(Event<'static>),
    /// The gateway requested an immediate heartbeat.
    Heartbeat,
    /// The gateway requested that the client reconnect.
    Reconnect,
    /// The session was invalidated; the flag tells whether it may be resumed.
    InvalidSession(bool),
    /// Nothing further to do (e.g. heartbeat ACK, incomplete compressed frame, ignored opcode).
    None,
}
49
/// Client WebSocket stream over TCP, optionally TLS-wrapped, as produced by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
52
/// Lifecycle state of a [`Shard`], as reported by [`Shard::state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardState {
    /// No active connection. A newly created shard starts in this state.
    Disconnected,
    /// Opening the WebSocket connection to the gateway.
    Connecting,
    /// WebSocket is connected; waiting for the gateway's Hello.
    Handshaking,
    /// Starting a new session by sending Identify.
    Identifying,
    /// Continuing a previous session by sending Resume.
    Resuming,
    /// The session is established (for example after READY has been processed).
    Connected,
    /// Waiting out the backoff delay before the next connection attempt.
    Reconnecting,
    /// The shard is shutting down.
    Disconnecting,
}
73
/// Connection settings for a [`Shard`].
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Authentication token sent in the Identify and Resume payloads.
    pub token: String,

    /// Gateway intents sent in the Identify payload.
    pub intents: titanium_model::Intents,

    /// Base gateway URL, used whenever no resume URL has been stored.
    pub gateway_url: String,

    /// Large-guild threshold.
    /// NOTE(review): not referenced by the shard code in this file (Identify is built
    /// without it) — confirm whether it should be forwarded.
    pub large_threshold: u8,

    /// When `true`, the connection URL requests `compress=zlib-stream`.
    pub compress: bool,

    /// Number of reconnect attempts `Shard::run` allows before returning an error.
    pub max_reconnect_attempts: u32,

    /// Base delay (milliseconds) passed to `exponential_backoff` between reconnects.
    pub reconnect_base_delay_ms: u64,

    /// Maximum delay (milliseconds) passed to `exponential_backoff` between reconnects.
    pub reconnect_max_delay_ms: u64,
}
101
102impl ShardConfig {
103 pub fn new(token: impl Into<String>, intents: titanium_model::Intents) -> Self {
105 Self {
106 token: token.into(),
107 intents,
108 gateway_url: DEFAULT_GATEWAY_URL.to_string(),
109 large_threshold: 250,
110 compress: false,
111 max_reconnect_attempts: 10,
112 reconnect_base_delay_ms: 1000,
113 reconnect_max_delay_ms: 60000,
114 }
115 }
116
117 pub fn with_gateway_url(mut self, url: impl Into<String>) -> Self {
119 self.gateway_url = url.into();
120 self
121 }
122}
123
/// Session details kept from the last READY event so a later connection can resume.
#[derive(Debug, Clone)]
struct SessionData {
    /// Session ID, sent back in the Resume payload.
    session_id: String,

    /// `resume_gateway_url` from READY; used as the base URL for the next connection.
    resume_url: String,
}
132
/// A single gateway shard: drives one WebSocket connection at a time and
/// reconnects or resumes it as needed (see `Shard::run`).
pub struct Shard {
    /// Index of this shard, sent in Identify together with `total_shards`.
    shard_id: u16,

    /// Total number of shards, sent in Identify.
    total_shards: u16,

    /// Connection settings.
    config: ShardConfig,

    /// Identify rate limiter; held in an `Arc` so several shards can share one.
    rate_limiter: Arc<IdentifyRateLimiter>,

    /// Current lifecycle state, readable through `Shard::state`.
    state: RwLock<ShardState>,

    /// Session to resume; `None` means the next connection sends Identify.
    session: RwLock<Option<SessionData>>,

    /// Last sequence number (`s`) received from the gateway; 0 until one arrives.
    sequence: AtomicU64,

    /// Tracks the heartbeat interval, ACK status and measured latency.
    heartbeat: HeartbeatHandler,

    /// Inflater for `zlib-stream` compressed binary frames.
    decompressor: RwLock<ZlibDecompressor>,

    /// Set by `Shard::shutdown`; polled by the connection and reconnect loops.
    shutdown: AtomicBool,

    /// Sending half of the command queue filled by `Shard::send_payload`.
    command_tx: Sender<ShardCommand>,

    /// Receiving half of the command queue, drained by the active connection loop.
    command_rx: flume::Receiver<ShardCommand>,
}
183
184impl Shard {
185 pub fn new(shard_id: u16, total_shards: u16, config: ShardConfig) -> Self {
192 Self::with_rate_limiter(
193 shard_id,
194 total_shards,
195 config,
196 Arc::new(IdentifyRateLimiter::default()),
197 )
198 }
199
200 pub fn with_rate_limiter(
202 shard_id: u16,
203 total_shards: u16,
204 config: ShardConfig,
205 rate_limiter: Arc<IdentifyRateLimiter>,
206 ) -> Self {
207 let (tx, rx) = flume::unbounded();
208
209 Self {
210 shard_id,
211 total_shards,
212 config,
213 rate_limiter,
214 state: RwLock::new(ShardState::Disconnected),
215 session: RwLock::new(None),
216 sequence: AtomicU64::new(0),
217 heartbeat: HeartbeatHandler::new(),
218 decompressor: RwLock::new(ZlibDecompressor::new()),
219 shutdown: AtomicBool::new(false),
220 command_tx: tx,
221 command_rx: rx,
222 }
223 }
224
225 pub fn shard_id(&self) -> u16 {
227 self.shard_id
228 }
229
230 pub fn total_shards(&self) -> u16 {
232 self.total_shards
233 }
234
235 pub fn state(&self) -> ShardState {
237 *self.state.read()
238 }
239
240 pub fn sequence(&self) -> u64 {
242 self.sequence.load(Ordering::SeqCst)
243 }
244
245 pub fn shutdown(&self) {
247 self.shutdown.store(true, Ordering::SeqCst);
248 }
249
250 pub fn latency(&self) -> Option<Duration> {
252 self.heartbeat.latency()
253 }
254
255 pub fn send_payload<T: serde::Serialize>(&self, payload: &T) -> Result<(), GatewayError> {
260 #[cfg(feature = "simd")]
261 let json = simd_json::to_string(payload).map_err(|e| GatewayError::Closed {
262 code: 0,
263 reason: format!("Serialization error: {}", e),
264 })?;
265
266 #[cfg(not(feature = "simd"))]
267 let json = serde_json::to_string(payload)?;
268
269 self.command_tx
270 .send(ShardCommand::Send(json))
271 .map_err(|_| GatewayError::Closed {
272 code: 0,
273 reason: "Shard command channel closed".to_string(),
274 })
275 }
276
277 pub async fn run(&self, event_tx: Sender<Event<'static>>) -> Result<(), GatewayError> {
288 let mut reconnect_attempts = 0u32;
289 let mut read_buffer = Vec::with_capacity(32 * 1024);
290
291 loop {
292 if self.shutdown.load(Ordering::SeqCst) {
294 info!(shard_id = self.shard_id, "Shard shutdown requested");
295 *self.state.write() = ShardState::Disconnecting;
296 return Ok(());
297 }
298
299 match self.connect_and_run(&event_tx, &mut read_buffer).await {
301 Ok(()) => {
302 return Ok(());
304 }
305 Err(GatewayError::HeartbeatTimeout) => {
306 warn!(
307 shard_id = self.shard_id,
308 "Heartbeat timeout, reconnecting..."
309 );
310 reconnect_attempts += 1;
311 }
312 Err(GatewayError::InvalidSession { resumable }) => {
313 if !resumable {
314 *self.session.write() = None;
316 self.sequence.store(0, Ordering::SeqCst);
317 }
318 warn!(
319 shard_id = self.shard_id,
320 resumable = resumable,
321 "Session invalidated, reconnecting..."
322 );
323 reconnect_attempts += 1;
324 }
325 Err(GatewayError::Closed { code, reason }) => {
326 let close_code = CloseCode::from_code(code);
327
328 if let Some(cc) = close_code {
329 if !cc.can_reconnect() {
330 error!(
331 shard_id = self.shard_id,
332 code = code,
333 reason = %reason,
334 "Fatal close code, cannot reconnect"
335 );
336 return Err(GatewayError::Closed { code, reason });
337 }
338 }
339
340 warn!(
341 shard_id = self.shard_id,
342 code = code,
343 reason = %reason,
344 "Connection closed, reconnecting..."
345 );
346 reconnect_attempts += 1;
347 }
348 Err(e) => {
349 error!(shard_id = self.shard_id, error = %e, "Shard error");
350 reconnect_attempts += 1;
351 }
352 }
353
354 if reconnect_attempts > self.config.max_reconnect_attempts {
356 error!(
357 shard_id = self.shard_id,
358 attempts = reconnect_attempts,
359 "Max reconnect attempts exceeded"
360 );
361 return Err(GatewayError::Closed {
362 code: 0,
363 reason: "Max reconnect attempts exceeded".to_string(),
364 });
365 }
366
367 let backoff = exponential_backoff(
369 reconnect_attempts - 1,
370 self.config.reconnect_base_delay_ms,
371 self.config.reconnect_max_delay_ms,
372 );
373 let backoff_with_jitter = with_jitter(backoff, 0.25);
374
375 info!(
376 shard_id = self.shard_id,
377 attempt = reconnect_attempts,
378 backoff_ms = backoff_with_jitter.as_millis(),
379 "Waiting before reconnect"
380 );
381
382 *self.state.write() = ShardState::Reconnecting;
383 sleep(backoff_with_jitter).await;
384 }
385 }
386
387 async fn connect_and_run(
389 &self,
390 event_tx: &Sender<Event<'static>>,
391 buffer: &mut Vec<u8>,
392 ) -> Result<(), GatewayError> {
393 let gateway_url = self.build_gateway_url()?;
395
396 info!(shard_id = self.shard_id, url = %gateway_url, "Connecting to Gateway");
397 *self.state.write() = ShardState::Connecting;
398
399 let (ws_stream, _response) = connect_async(gateway_url.as_str()).await?;
401 let (mut sink, mut stream) = ws_stream.split();
402
403 info!(shard_id = self.shard_id, "WebSocket connected");
404 *self.state.write() = ShardState::Handshaking;
405
406 let hello = self.wait_for_hello(&mut stream).await?;
408 let heartbeat_interval = Duration::from_millis(hello.heartbeat_interval);
409 self.heartbeat.set_interval(heartbeat_interval);
410
411 debug!(
412 shard_id = self.shard_id,
413 interval_ms = hello.heartbeat_interval,
414 "Received Hello"
415 );
416
417 self.rate_limiter.acquire().await;
419
420 let session = self.session.read().clone();
421 if let Some(ref session_data) = session {
422 *self.state.write() = ShardState::Resuming;
424 info!(
425 shard_id = self.shard_id,
426 session_id = %session_data.session_id,
427 "Resuming session"
428 );
429 self.send_resume(&mut sink, session_data).await?;
430 } else {
431 *self.state.write() = ShardState::Identifying;
433 info!(shard_id = self.shard_id, "Sending Identify");
434 self.send_identify(&mut sink).await?;
435 }
436
437 self.heartbeat.reset();
439
440 self.send_heartbeat(&mut sink).await?;
442 self.heartbeat.mark_sent();
443
444 let mut next_heartbeat = Instant::now() + heartbeat_interval;
446
447 loop {
449 if self.shutdown.load(Ordering::SeqCst) {
451 let _ = sink.close().await;
453 return Ok(());
454 }
455
456 tokio::select! {
457 message = stream.next() => {
459 match message {
460 Some(Ok(msg)) => {
461 self.handle_message(msg, event_tx, &mut sink, buffer).await?;
462 }
463 Some(Err(e)) => {
464 return Err(GatewayError::WebSocket(e));
465 }
466 None => {
467 return Err(GatewayError::Closed {
469 code: 0,
470 reason: "WebSocket stream ended".to_string(),
471 });
472 }
473 }
474 }
475
476 _ = sleep(next_heartbeat.saturating_duration_since(Instant::now())) => {
478 if !self.heartbeat.is_acked() {
480 error!(shard_id = self.shard_id, "No heartbeat ACK received, assuming zombie connection");
481 return Err(GatewayError::HeartbeatTimeout);
482 }
483
484 self.send_heartbeat(&mut sink).await?;
486 self.heartbeat.mark_sent();
487
488 next_heartbeat = Instant::now() + self.heartbeat.interval();
490 }
491
492 command = self.command_rx.recv_async() => {
494 match command {
495 Ok(ShardCommand::Send(json)) => {
496 trace!(shard_id = self.shard_id, "Sending custom payload");
497 sink.send(WsMessage::Text(json.into())).await?;
498 }
499 Err(_) => {
500 return Err(GatewayError::Closed {
502 code: 0,
503 reason: "Command channel closed".to_string(),
504 });
505 }
506 }
507 }
508 }
509 }
510 }
511
512 fn build_gateway_url(&self) -> Result<Url, GatewayError> {
514 let base_url = self
516 .session
517 .read()
518 .as_ref()
519 .map(|s| s.resume_url.clone())
520 .unwrap_or_else(|| self.config.gateway_url.clone());
521
522 let mut url = Url::parse(&base_url).map_err(|e| GatewayError::Closed {
523 code: 0,
524 reason: format!("Invalid URL: {}", e),
525 })?;
526
527 url.query_pairs_mut()
529 .append_pair("v", &GATEWAY_VERSION.to_string())
530 .append_pair("encoding", "json");
531
532 if self.config.compress {
533 url.query_pairs_mut().append_pair("compress", "zlib-stream");
534 }
535
536 Ok(url)
537 }
538
539 async fn wait_for_hello(
541 &self,
542 stream: &mut futures_util::stream::SplitStream<WsStream>,
543 ) -> Result<HelloPayload, GatewayError> {
544 let hello_timeout = Duration::from_secs(10);
546
547 let message = timeout(hello_timeout, stream.next())
548 .await
549 .map_err(|_| GatewayError::Closed {
550 code: 0,
551 reason: "Timeout waiting for Hello".to_string(),
552 })?
553 .ok_or_else(|| GatewayError::Closed {
554 code: 0,
555 reason: "Connection closed before Hello".to_string(),
556 })??;
557
558 if let WsMessage::Text(text) = message {
559 let payload: RawGatewayPayload = serde_json::from_str(&text)?;
560
561 if payload.op == OpCode::Hello {
562 if let Some(data) = payload.d {
563 #[cfg(feature = "simd")]
564 let hello: HelloPayload = titanium_model::json::from_value(data)?;
565 #[cfg(not(feature = "simd"))]
566 let hello: HelloPayload = serde_json::from_str(data.get())?;
567 return Ok(hello);
568 }
569 }
570 }
571
572 Err(GatewayError::Closed {
573 code: 0,
574 reason: "Expected Hello payload".to_string(),
575 })
576 }
577
578 async fn send_identify(
580 &self,
581 sink: &mut futures_util::stream::SplitSink<WsStream, WsMessage>,
582 ) -> Result<(), GatewayError> {
583 let identify = IdentifyPayload::new(
584 std::borrow::Cow::Borrowed(self.config.token.as_str()),
585 self.config.intents,
586 )
587 .with_shard(self.shard_id, self.total_shards);
588
589 let payload = GatewayPayload::new(OpCode::Identify, identify);
590
591 #[cfg(feature = "simd")]
592 let json = simd_json::to_string(&payload).map_err(|e| GatewayError::Closed {
593 code: 0,
594 reason: e.to_string(),
595 })?;
596
597 #[cfg(not(feature = "simd"))]
598 let json = serde_json::to_string(&payload)?;
599
600 trace!(shard_id = self.shard_id, "Sending Identify payload");
601 sink.send(WsMessage::Text(json.into())).await?;
602
603 Ok(())
604 }
605
606 async fn send_resume(
608 &self,
609 sink: &mut futures_util::stream::SplitSink<WsStream, WsMessage>,
610 session: &SessionData,
611 ) -> Result<(), GatewayError> {
612 let resume = ResumePayload {
613 token: std::borrow::Cow::Borrowed(self.config.token.as_str()),
614 session_id: std::borrow::Cow::Borrowed(session.session_id.as_str()),
615 seq: self.sequence.load(Ordering::SeqCst),
616 };
617
618 let payload = GatewayPayload::new(OpCode::Resume, resume);
619
620 #[cfg(feature = "simd")]
621 let json = simd_json::to_string(&payload).map_err(|e| GatewayError::Closed {
622 code: 0,
623 reason: e.to_string(),
624 })?;
625
626 #[cfg(not(feature = "simd"))]
627 let json = serde_json::to_string(&payload)?;
628
629 trace!(shard_id = self.shard_id, "Sending Resume payload");
630 sink.send(WsMessage::Text(json.into())).await?;
631
632 Ok(())
633 }
634
635 async fn send_heartbeat(
637 &self,
638 sink: &mut futures_util::stream::SplitSink<WsStream, WsMessage>,
639 ) -> Result<(), GatewayError> {
640 let seq = self.sequence.load(Ordering::SeqCst);
641 let seq_opt = if seq > 0 { Some(seq) } else { None };
642
643 let json = create_heartbeat_payload(seq_opt);
644
645 trace!(shard_id = self.shard_id, seq = seq, "Sending Heartbeat");
646 sink.send(WsMessage::Text(json.into())).await?;
647
648 Ok(())
649 }
650
651 async fn handle_message(
653 &self,
654 message: WsMessage,
655 event_tx: &Sender<Event<'static>>,
656 sink: &mut futures_util::stream::SplitSink<WsStream, WsMessage>,
657 buffer: &mut Vec<u8>,
658 ) -> Result<(), GatewayError> {
659 let action = match message {
660 WsMessage::Text(text) => {
661 buffer.clear();
663 buffer.extend_from_slice(text.as_str().as_bytes());
664 self.process_frame(buffer)?
665 }
666 WsMessage::Binary(data) => {
667 let mut decompressor = self.decompressor.write();
670 match decompressor.push(&data) {
671 Ok(Some(msg)) => self.process_frame(msg)?,
672 Ok(None) => GatewayAction::None, Err(e) => {
674 return Err(GatewayError::JsonDecode(format!(
675 "Decompression error: {}",
676 e
677 )))
678 }
679 }
680 }
681 WsMessage::Close(frame) => {
682 let (code, reason) = frame
683 .map(|f: CloseFrame| (f.code.into(), f.reason.to_string()))
684 .unwrap_or((0, String::new()));
685
686 return Err(GatewayError::Closed { code, reason });
687 }
688 WsMessage::Ping(data) => {
689 sink.send(WsMessage::Pong(data)).await?;
690 return Ok(());
691 }
692 WsMessage::Pong(_) => return Ok(()),
693 WsMessage::Frame(_) => return Ok(()),
694 };
695
696 match action {
697 GatewayAction::Dispatch(event) => {
698 event_tx.send_async(event).await?;
699 }
700 GatewayAction::Heartbeat => {
701 debug!(shard_id = self.shard_id, "Received Heartbeat request");
702 self.send_heartbeat(sink).await?;
703 }
704 GatewayAction::Reconnect => {
705 info!(shard_id = self.shard_id, "Received Reconnect request");
706 return Err(GatewayError::Closed {
707 code: 0,
708 reason: "Server requested reconnect".to_string(),
709 });
710 }
711 GatewayAction::InvalidSession(resumable) => {
712 warn!(
713 shard_id = self.shard_id,
714 resumable = resumable,
715 "Session invalidated"
716 );
717 return Err(GatewayError::InvalidSession { resumable });
718 }
719 GatewayAction::None => {}
720 }
721
722 Ok(())
723 }
724
725 fn process_frame(&self, text: &mut [u8]) -> Result<GatewayAction, GatewayError> {
731 #[cfg(feature = "simd")]
732 {
733 let json = titanium_model::json::to_borrowed_value(text)
735 .map_err(|e| GatewayError::JsonDecode(e.to_string()))?;
736
737 if let Some(seq) = json["s"].as_u64() {
739 self.sequence.store(seq, Ordering::SeqCst);
740 }
741
742 let op_val = json["op"].clone();
744 let op: OpCode = titanium_model::json::from_borrowed_value(op_val)
745 .map_err(|e| GatewayError::JsonDecode(e.to_string()))?;
746
747 match op {
748 OpCode::Dispatch => {
749 let d_val = json["d"].clone();
750 if let Some(event_name) = json["t"].as_str() {
751 let event_result = parse_event(event_name, d_val)?;
752
753 if let Event::Ready(ref ready) = event_result {
754 self.handle_ready(ready);
755 }
756
757 return Ok(GatewayAction::Dispatch(event_result));
758 }
759 }
760
761 OpCode::Heartbeat => return Ok(GatewayAction::Heartbeat),
762 OpCode::Reconnect => return Ok(GatewayAction::Reconnect),
763
764 OpCode::InvalidSession => {
765 let resumable = json["d"].as_bool().unwrap_or(false);
766 return Ok(GatewayAction::InvalidSession(resumable));
767 }
768
769 OpCode::HeartbeatAck => {
770 self.heartbeat.mark_acked();
771 let rtt = self.heartbeat.latency().unwrap_or_default();
772 trace!(
773 shard_id = self.shard_id,
774 rtt_ms = rtt.as_millis(),
775 "Heartbeat ACK received"
776 );
777 }
778
779 _ => {
780 trace!(
781 shard_id = self.shard_id,
782 opcode = ?op,
783 "Ignoring opcode"
784 );
785 }
786 }
787 }
788
789 #[cfg(not(feature = "simd"))]
790 {
791 let payload: RawGatewayPayload = titanium_model::json::from_slice_mut(text)
792 .map_err(|e| GatewayError::JsonDecode(e.to_string()))?;
793
794 if let Some(seq) = payload.s {
795 self.sequence.store(seq, Ordering::SeqCst);
796 }
797
798 match payload.op {
799 OpCode::Dispatch => {
800 if let (Some(event_name), Some(data)) = (payload.t.as_deref(), payload.d) {
801 let event_name = event_name.to_string();
802 let json_string = data.get().to_string();
805
806 let raw_value = serde_json::value::RawValue::from_string(json_string)
807 .map_err(GatewayError::from)?;
808 let event_result = parse_event(&event_name, &raw_value)?;
809
810 if let Event::Ready(ref ready) = event_result {
811 self.handle_ready(ready);
812 }
813 return Ok(GatewayAction::Dispatch(event_result));
814 }
815 }
816 OpCode::Heartbeat => return Ok(GatewayAction::Heartbeat),
817 OpCode::Reconnect => return Ok(GatewayAction::Reconnect),
818 OpCode::InvalidSession => {
819 let resumable = payload.d.map(|d| d.get() == "true").unwrap_or(false);
823 return Ok(GatewayAction::InvalidSession(resumable));
824 }
825 OpCode::HeartbeatAck => {
826 self.heartbeat.mark_acked();
827 }
828 _ => {}
829 }
830 }
831
832 Ok(GatewayAction::None)
833 }
834
835
836 fn handle_ready(&self, ready: &ReadyEventData) {
838 *self.session.write() = Some(SessionData {
839 session_id: ready.session_id.clone(),
840 resume_url: ready.resume_gateway_url.clone(),
841 });
842 *self.state.write() = ShardState::Connected;
843
844 info!(
845 shard_id = self.shard_id,
846 session_id = %ready.session_id,
847 guilds = ready.guilds.len(),
848 "Shard connected"
849 );
850 }
851}
852
#[cfg(test)]
mod tests {
    use super::*;
    use titanium_model::Intents;

    #[test]
    fn test_shard_config() {
        let config = ShardConfig::new("test_token", Intents::GUILDS | Intents::GUILD_MESSAGES);
        assert_eq!(config.token, "test_token");
        assert!(config.intents.contains(Intents::GUILDS));
        assert!(!config.compress);
        assert_eq!(config.gateway_url, DEFAULT_GATEWAY_URL.to_string());
    }

    #[test]
    fn test_with_gateway_url_only_replaces_url() {
        let config =
            ShardConfig::new("test_token", Intents::default()).with_gateway_url("wss://example.com");
        assert_eq!(config.gateway_url, "wss://example.com");
        assert_eq!(config.token, "test_token");
        assert_eq!(config.max_reconnect_attempts, 10);
    }

    #[test]
    fn test_shard_creation() {
        let config = ShardConfig::new("test_token", Intents::default());
        let shard = Shard::new(0, 1, config);

        assert_eq!(shard.shard_id(), 0);
        assert_eq!(shard.total_shards(), 1);
        assert_eq!(shard.state(), ShardState::Disconnected);
        assert_eq!(shard.sequence(), 0);
    }

    #[test]
    fn test_gateway_url_building() {
        let config = ShardConfig::new("test", Intents::default());
        let shard = Shard::new(0, 1, config);

        let url = shard.build_gateway_url().expect("Failed to build URL");
        assert!(url.as_str().contains("v=10"));
        assert!(url.as_str().contains("encoding=json"));
        assert!(!url.as_str().contains("compress="));
    }

    #[test]
    fn test_gateway_url_requests_compression_when_enabled() {
        let mut config = ShardConfig::new("test", Intents::default());
        config.compress = true;
        let shard = Shard::new(0, 1, config);

        let url = shard.build_gateway_url().expect("Failed to build URL");
        assert!(url.as_str().contains("compress=zlib-stream"));
    }

    #[test]
    fn test_gateway_url_prefers_resume_url() {
        let shard = Shard::new(0, 1, ShardConfig::new("test", Intents::default()));
        *shard.session.write() = Some(SessionData {
            session_id: "session".to_string(),
            resume_url: "wss://resume.example.com".to_string(),
        });

        let url = shard.build_gateway_url().expect("Failed to build URL");
        assert_eq!(url.host_str(), Some("resume.example.com"));
        assert!(url.as_str().contains("encoding=json"));
    }

    #[test]
    fn test_send_payload_queues_serialized_json() {
        let shard = Shard::new(0, 1, ShardConfig::new("test", Intents::default()));
        shard
            .send_payload(&serde_json::json!({ "op": 1 }))
            .expect("payload should serialize");

        let command = shard
            .command_rx
            .try_recv()
            .expect("payload should be queued");
        match command {
            ShardCommand::Send(json) => assert!(json.contains("\"op\":1")),
        }
    }
}