use crate::error::GatewayError;
use crate::event::Event;
use crate::ratelimit::IdentifyRateLimiter;
use crate::shard::{Shard, ShardConfig, ShardState};

use dashmap::DashMap;
use flume::{Receiver, Sender};
use std::sync::Arc;
use titanium_model::Intents;
use tokio::task::JoinHandle;
use tracing::{error, info};

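/// Selects which shards a [`Cluster`] should run, out of the bot's total
/// shard count.
///
/// # Example
///
/// A minimal sketch of the three variants (illustrative only; the shard
/// counts are placeholders):
///
/// ```ignore
/// // Run every shard of a 10-shard bot.
/// let all = ShardRange::All { total: 10 };
/// assert_eq!(all.shard_ids(), (0..10).collect::<Vec<u16>>());
///
/// // Run shards 0..5 (end-exclusive) of a 10-shard bot.
/// let range = ShardRange::Range { start: 0, end: 5, total: 10 };
/// assert_eq!(range.shard_ids().len(), 5);
///
/// // Run an explicit list of shard IDs.
/// let specific = ShardRange::Specific { ids: vec![0, 5, 10], total: 20 };
/// assert_eq!(specific.total_shards(), 20);
/// ```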
#[derive(Debug, Clone)]
pub enum ShardRange {
    /// Run every shard from `0` up to `total`.
    All {
        /// Total number of shards the bot uses.
        total: u16,
    },

    /// Run a contiguous, end-exclusive range of shard IDs.
    Range {
        /// First shard ID in the range (inclusive).
        start: u16,
        /// One past the last shard ID in the range (exclusive).
        end: u16,
        /// Total number of shards the bot uses, across all clusters.
        total: u16,
    },

    /// Run an explicit list of shard IDs.
    Specific {
        /// The shard IDs this cluster is responsible for.
        ids: Vec<u16>,
        /// Total number of shards the bot uses, across all clusters.
        total: u16,
    },
}

impl ShardRange {
    /// Returns the concrete shard IDs described by this range.
    pub fn shard_ids(&self) -> Vec<u16> {
        match self {
            ShardRange::All { total } => (0..*total).collect(),
            ShardRange::Range { start, end, .. } => (*start..*end).collect(),
            ShardRange::Specific { ids, .. } => ids.clone(),
        }
    }

    /// Returns the total number of shards, regardless of which shards this
    /// range selects.
    pub fn total_shards(&self) -> u16 {
        match self {
            ShardRange::All { total } => *total,
            ShardRange::Range { total, .. } => *total,
            ShardRange::Specific { total, .. } => *total,
        }
    }
}

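/// Configuration for a [`Cluster`] of shards.
///
/// # Example
///
/// A minimal sketch of building a configuration by hand (the token is a
/// placeholder, and `Intents::GUILDS` is just an example intent set):
///
/// ```ignore
/// let config = ClusterConfig::new(
///     "your-bot-token",
///     Intents::GUILDS,
///     ShardRange::All { total: 2 },
/// )
/// .with_max_concurrency(1);
/// ```
///
/// With the `auto-sharding` feature enabled, [`ClusterConfig::autoscaled`]
/// can fill in the shard count and gateway URL from Discord instead.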
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Bot token used to authenticate every shard.
    pub token: String,

    /// Gateway intents requested when identifying.
    pub intents: Intents,

    /// Which shards this cluster runs.
    pub shard_range: ShardRange,

    /// Gateway URL to connect to.
    pub gateway_url: String,

    /// How many shards may identify concurrently (Discord reports this as
    /// `max_concurrency` in the session start limit).
    pub max_concurrency: usize,

    /// Member count threshold above which a guild is considered "large" by
    /// the gateway.
    pub large_threshold: u8,
}

impl ClusterConfig {
    /// Creates a configuration with the default gateway URL, a
    /// `max_concurrency` of 1, and a `large_threshold` of 250.
    pub fn new(token: impl Into<String>, intents: Intents, shard_range: ShardRange) -> Self {
        Self {
            token: token.into(),
            intents,
            shard_range,
            gateway_url: crate::DEFAULT_GATEWAY_URL.to_string(),
            max_concurrency: 1,
            large_threshold: 250,
        }
    }

    /// Sets how many shards may identify concurrently.
    pub fn with_max_concurrency(mut self, max_concurrency: usize) -> Self {
        self.max_concurrency = max_concurrency;
        self
    }

    /// Overrides the gateway URL (useful for proxies or testing).
    pub fn with_gateway_url(mut self, url: impl Into<String>) -> Self {
        self.gateway_url = url.into();
        self
    }

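    /// Builds a configuration using Discord's recommended shard count.
    ///
    /// Gateway information is fetched via `titanium_http::HttpClient::get_gateway_bot`,
    /// and the returned shard count, gateway URL, and `max_concurrency` are used.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `auto-sharding` feature is enabled and the
    /// call runs inside an async runtime (the token is a placeholder):
    ///
    /// ```ignore
    /// let config = ClusterConfig::autoscaled("your-bot-token", Intents::GUILDS).await?;
    /// println!("running {} shards", config.shard_range.total_shards());
    /// ```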
    #[cfg(feature = "auto-sharding")]
    pub async fn autoscaled(
        token: impl Into<String>,
        intents: titanium_model::Intents,
    ) -> Result<Self, crate::error::GatewayError> {
        use titanium_http::HttpClient;

        let token = token.into();
        let client = HttpClient::new(&token).map_err(|_| crate::error::GatewayError::Closed {
            code: 0,
            reason: "Failed to create HTTP client for auto-sharding".into(),
        })?;

        let info =
            client
                .get_gateway_bot()
                .await
                .map_err(|e| crate::error::GatewayError::Closed {
                    code: 0,
                    reason: format!("Failed to fetch gateway info: {}", e),
                })?;

        Ok(Self {
            token,
            intents,
            shard_range: ShardRange::All { total: info.shards },
            gateway_url: info.url,
            max_concurrency: info.session_start_limit.max_concurrency as usize,
            large_threshold: 250,
        })
    }
}

/// A spawned shard together with the handle of the task driving its connection.
struct ShardRunner {
    shard: Arc<Shard>,
    handle: JoinHandle<Result<(), GatewayError>>,
}

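/// Manages a group of shards that share a token, gateway configuration, and
/// identify rate limiter, and funnels their events into a single channel.
///
/// # Example
///
/// A minimal sketch of running a cluster and draining its events. The token
/// and intents are placeholders, the import paths assume this crate is
/// available as `titanium_gateway`, and an async runtime (e.g. `#[tokio::main]`)
/// is assumed:
///
/// ```ignore
/// use titanium_gateway::{Cluster, ClusterConfig, ShardRange};
/// use titanium_model::Intents;
///
/// let config = ClusterConfig::new(
///     "your-bot-token",
///     Intents::GUILDS,
///     ShardRange::All { total: 2 },
/// );
///
/// let (cluster, events) = Cluster::new(config);
/// cluster.start().await?;
///
/// while let Ok((shard_id, _event)) = events.recv_async().await {
///     println!("got an event on shard {shard_id}");
/// }
/// ```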
pub struct Cluster {
    /// Shared configuration for every shard in this cluster.
    config: ClusterConfig,

    /// Running shards, keyed by shard ID.
    shards: DashMap<u16, ShardRunner>,

    /// Identify rate limiter shared by all shards, constructed from
    /// `max_concurrency`.
    rate_limiter: Arc<IdentifyRateLimiter>,

    /// Sending half of the cluster-wide event channel.
    event_tx: Sender<(u16, Event<'static>)>,
}

impl Cluster {
    /// Creates a cluster and the receiving half of its event channel.
    ///
    /// Events are delivered as `(shard_id, event)` pairs; no shards are
    /// spawned until [`Cluster::start`] is called.
    pub fn new(config: ClusterConfig) -> (Self, Receiver<(u16, Event<'static>)>) {
        let (event_tx, event_rx) = flume::unbounded();
        let rate_limiter = Arc::new(IdentifyRateLimiter::new(config.max_concurrency));

        let cluster = Self {
            config,
            shards: DashMap::new(),
            rate_limiter,
            event_tx,
        };

        (cluster, event_rx)
    }

    /// Spawns every shard in the configured [`ShardRange`] and begins
    /// connecting them. Identify attempts are coordinated by the shared
    /// [`IdentifyRateLimiter`].
    pub async fn start(&self) -> Result<(), GatewayError> {
        let shard_ids = self.config.shard_range.shard_ids();
        let total_shards = self.config.shard_range.total_shards();

        info!(
            shards = ?shard_ids,
            total = total_shards,
            max_concurrency = self.config.max_concurrency,
            "Starting cluster"
        );

        for shard_id in shard_ids {
            self.spawn_shard(shard_id, total_shards)?;
        }

        Ok(())
    }

    /// Spawns a single shard plus a task that forwards its events, tagged
    /// with the shard ID, into the cluster-wide channel.
    fn spawn_shard(&self, shard_id: u16, total_shards: u16) -> Result<(), GatewayError> {
        let shard_config = ShardConfig {
            token: self.config.token.clone(),
            intents: self.config.intents,
            gateway_url: self.config.gateway_url.clone(),
            large_threshold: self.config.large_threshold,
            compress: false,
            max_reconnect_attempts: 10,
            reconnect_base_delay_ms: 1000,
            reconnect_max_delay_ms: 60000,
        };

        let shard = Arc::new(Shard::with_rate_limiter(
            shard_id,
            total_shards,
            shard_config,
            self.rate_limiter.clone(),
        ));

        // Per-shard channel; the forwarding task below moves events from it
        // into the cluster-wide channel.
        let (shard_tx, shard_rx) = flume::unbounded::<Event<'static>>();
        let cluster_tx = self.event_tx.clone();
        let shard_id_for_forward = shard_id;

        tokio::spawn(async move {
            while let Ok(event) = shard_rx.recv_async().await {
                if cluster_tx
                    .send_async((shard_id_for_forward, event))
                    .await
                    .is_err()
                {
                    // The cluster-side receiver was dropped; stop forwarding.
                    break;
                }
            }
        });

        let shard_clone = shard.clone();
        let handle = tokio::spawn(async move { shard_clone.run(shard_tx).await });

        self.shards.insert(shard_id, ShardRunner { shard, handle });

        info!(shard_id = shard_id, "Shard spawned");
        Ok(())
    }

    /// Returns the connection state of a shard, if it exists in this cluster.
    pub fn shard_state(&self, shard_id: u16) -> Option<ShardState> {
        self.shards.get(&shard_id).map(|r| r.shard.state())
    }

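    /// Returns the last measured gateway latency of a shard, if available.
    ///
    /// # Example
    ///
    /// A small monitoring sketch (assumes `cluster` is a started [`Cluster`]):
    ///
    /// ```ignore
    /// for id in cluster.shard_ids() {
    ///     if let Some(latency) = cluster.shard_latency(id) {
    ///         println!("shard {id}: {latency:?}");
    ///     }
    /// }
    /// ```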
    pub fn shard_latency(&self, shard_id: u16) -> Option<std::time::Duration> {
        self.shards.get(&shard_id).and_then(|r| r.shard.latency())
    }

    /// Returns the IDs of all shards currently managed by this cluster.
    pub fn shard_ids(&self) -> Vec<u16> {
        self.shards.iter().map(|r| *r.key()).collect()
    }

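    /// Sends a raw gateway payload over a specific shard's connection.
    ///
    /// # Example
    ///
    /// A sketch of sending an *update presence* payload (opcode 3); the
    /// payload shape follows Discord's gateway documentation:
    ///
    /// ```ignore
    /// let payload = serde_json::json!({
    ///     "op": 3,
    ///     "d": {
    ///         "since": null,
    ///         "activities": [],
    ///         "status": "online",
    ///         "afk": false,
    ///     }
    /// });
    /// cluster.send(0, payload)?;
    /// ```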
    pub fn send(&self, shard_id: u16, payload: serde_json::Value) -> Result<(), GatewayError> {
        if let Some(runner) = self.shards.get(&shard_id) {
            runner.shard.send_payload(&payload)
        } else {
            Err(GatewayError::Closed {
                code: 0,
                reason: format!("Shard {} not found", shard_id),
            })
        }
    }

    /// Shuts down every shard and waits for their tasks to finish.
    pub async fn shutdown(&self) {
        info!("Shutting down cluster");

        // Ask every shard to close its connection.
        for shard in self.shards.iter() {
            shard.shard.shutdown();
        }

        // Take each runner out of the map before awaiting its task so that no
        // DashMap guard is held across an `.await`.
        let shard_ids: Vec<u16> = self.shards.iter().map(|r| *r.key()).collect();
        for shard_id in shard_ids {
            if let Some((_, runner)) = self.shards.remove(&shard_id) {
                if let Err(e) = runner.handle.await {
                    error!(shard_id = shard_id, error = %e, "Shard task panicked");
                }
            }
        }

        info!("Cluster shutdown complete");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_range_all() {
        let range = ShardRange::All { total: 10 };
        let ids = range.shard_ids();
        assert_eq!(ids.len(), 10);
        assert_eq!(ids[0], 0);
        assert_eq!(ids[9], 9);
    }

    #[test]
    fn test_shard_range_specific() {
        let range = ShardRange::Specific {
            ids: vec![0, 5, 10],
            total: 20,
        };
        let ids = range.shard_ids();
        assert_eq!(ids, vec![0, 5, 10]);
        assert_eq!(range.total_shards(), 20);
    }
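
    // Added sketch: a direct check that `ShardRange::Range` is end-exclusive,
    // complementing the indirect check in `test_cluster_config` below.
    #[test]
    fn test_shard_range_range_is_end_exclusive() {
        let range = ShardRange::Range {
            start: 2,
            end: 6,
            total: 10,
        };
        assert_eq!(range.shard_ids(), vec![2, 3, 4, 5]);
        assert_eq!(range.total_shards(), 10);
    }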

    #[test]
    fn test_cluster_config() {
        let config = ClusterConfig::new(
            "test_token",
            Intents::GUILDS,
            ShardRange::Range {
                start: 0,
                end: 5,
                total: 10,
            },
        )
        .with_max_concurrency(16);

        assert_eq!(config.max_concurrency, 16);
        assert_eq!(config.shard_range.shard_ids().len(), 5);
    }
}