titanium_gateway/
cluster.rs

//! Cluster management for multi-shard deployments.
//!
//! A Cluster manages multiple Shards, distributing work across them.
//! For very large bots (1M+ guilds), multiple Clusters can run on
//! different machines with coordinated shard ranges.
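//!
//! For instance, 1,024 total shards could be split across two machines
//! (hypothetical counts; real values come from `/gateway/bot`):
//!
//! ```ignore
//! // Machine A manages shards 0..512:
//! let range_a = ShardRange::Range { start: 0, end: 512, total: 1024 };
//! // Machine B manages shards 512..1024:
//! let range_b = ShardRange::Range { start: 512, end: 1024, total: 1024 };
//! ```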

use crate::error::GatewayError;
use crate::event::Event;
use crate::ratelimit::IdentifyRateLimiter;
use crate::shard::{Shard, ShardConfig, ShardState};

use dashmap::DashMap;
use flume::{Receiver, Sender};
use std::sync::Arc;
use titanium_model::Intents;
use tokio::task::JoinHandle;
use tracing::{error, info};

/// Specifies which shards a Cluster should manage.
#[derive(Debug, Clone)]
pub enum ShardRange {
    /// Manage all shards (total count from API or config).
    All {
        /// Total number of shards.
        total: u16,
    },

    /// Manage a specific range of shards.
    Range {
        /// Starting shard ID (inclusive).
        start: u16,
        /// Ending shard ID (exclusive).
        end: u16,
        /// Total number of shards across all clusters.
        total: u16,
    },

    /// Manage specific shard IDs.
    Specific {
        /// Shard IDs to manage.
        ids: Vec<u16>,
        /// Total number of shards across all clusters.
        total: u16,
    },
}

impl ShardRange {
    /// Get the shard IDs this range covers.
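    ///
    /// For example (hypothetical values):
    ///
    /// ```ignore
    /// let range = ShardRange::Range { start: 2, end: 5, total: 10 };
    /// assert_eq!(range.shard_ids(), vec![2, 3, 4]);
    /// ```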
    pub fn shard_ids(&self) -> Vec<u16> {
        match self {
            ShardRange::All { total } => (0..*total).collect(),
            ShardRange::Range { start, end, .. } => (*start..*end).collect(),
            ShardRange::Specific { ids, .. } => ids.clone(),
        }
    }

    /// Get the total number of shards.
    pub fn total_shards(&self) -> u16 {
        match self {
            ShardRange::All { total } => *total,
            ShardRange::Range { total, .. } => *total,
            ShardRange::Specific { total, .. } => *total,
        }
    }
}

/// Configuration for a Cluster.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Bot token.
    pub token: String,

    /// Gateway intents.
    pub intents: Intents,

    /// Which shards to manage.
    pub shard_range: ShardRange,

    /// Gateway URL (usually from /gateway/bot).
    pub gateway_url: String,

    /// Maximum concurrent identify operations (from /gateway/bot).
    pub max_concurrency: usize,

    /// Large guild threshold sent in the identify payload.
    pub large_threshold: u8,
}

impl ClusterConfig {
    /// Create a new cluster configuration.
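    ///
    /// Defaults: the library's default gateway URL, `max_concurrency = 1`,
    /// and `large_threshold = 250`; override them with the builder methods
    /// below. A minimal sketch (placeholder token):
    ///
    /// ```ignore
    /// let config = ClusterConfig::new("token", Intents::GUILDS, ShardRange::All { total: 2 })
    ///     .with_max_concurrency(16);
    /// ```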
    pub fn new(token: impl Into<String>, intents: Intents, shard_range: ShardRange) -> Self {
        Self {
            token: token.into(),
            intents,
            shard_range,
            gateway_url: crate::DEFAULT_GATEWAY_URL.to_string(),
            max_concurrency: 1,
            large_threshold: 250,
        }
    }

    /// Set the maximum concurrency (from the /gateway/bot response).
    pub fn with_max_concurrency(mut self, max_concurrency: usize) -> Self {
        self.max_concurrency = max_concurrency;
        self
    }

    /// Set the gateway URL.
    pub fn with_gateway_url(mut self, url: impl Into<String>) -> Self {
        self.gateway_url = url.into();
        self
    }

    /// Create a new cluster configuration with an auto-detected shard count.
    ///
    /// This requires the `auto-sharding` feature.
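    ///
    /// A usage sketch (feature enabled; placeholder token):
    ///
    /// ```ignore
    /// let config = ClusterConfig::autoscaled("token", Intents::GUILDS).await?;
    /// ```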
    #[cfg(feature = "auto-sharding")]
    pub async fn autoscaled(
        token: impl Into<String>,
        intents: Intents,
    ) -> Result<Self, GatewayError> {
        use titanium_http::HttpClient;

        let token = token.into();
        let client = HttpClient::new(&token).map_err(|_| GatewayError::Closed {
            code: 0,
            reason: "Failed to create HTTP client for auto-sharding".into(),
        })?;

        let info = client
            .get_gateway_bot()
            .await
            .map_err(|e| GatewayError::Closed {
                code: 0,
                reason: format!("Failed to fetch gateway info: {}", e),
            })?;

        Ok(Self {
            token,
            intents,
            shard_range: ShardRange::All { total: info.shards },
            gateway_url: info.url,
            max_concurrency: info.session_start_limit.max_concurrency as usize,
            large_threshold: 250,
        })
    }
}

/// A running shard with its task handle.
struct ShardRunner {
    /// The shard instance.
    shard: Arc<Shard>,
    /// The task handle for the shard's event loop.
    handle: JoinHandle<Result<(), GatewayError>>,
}

/// A Cluster manages multiple Gateway Shards.
///
/// The Cluster handles:
/// - Spawning and managing shard tasks
/// - Coordinating identify rate limiting across shards
/// - Aggregating events from all shards
///
/// # Example
///
/// ```ignore
/// use titanium_gateway::{Cluster, ClusterConfig, ShardRange};
/// use titanium_model::Intents;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let config = ClusterConfig::new(
///         "your-token",
///         Intents::GUILDS | Intents::GUILD_MESSAGES,
///         ShardRange::All { total: 1 },
///     );
///
///     let (cluster, events) = Cluster::new(config);
///     cluster.start().await?;
///
///     while let Ok((shard_id, event)) = events.recv_async().await {
///         println!("Shard {}: {:?}", shard_id, event);
///     }
///
///     Ok(())
/// }
/// ```
pub struct Cluster {
    /// Cluster configuration.
    config: ClusterConfig,

    /// Running shards.
    shards: DashMap<u16, ShardRunner>,

    /// Shared rate limiter for identify.
    rate_limiter: Arc<IdentifyRateLimiter>,

    /// Channel to send shard events.
    event_tx: Sender<(u16, Event<'static>)>,
}

impl Cluster {
    /// Create a new Cluster.
    ///
    /// Returns the Cluster and a receiver for events from all shards.
    /// Events are tagged with the shard ID they came from.
    pub fn new(config: ClusterConfig) -> (Self, Receiver<(u16, Event<'static>)>) {
        let (event_tx, event_rx) = flume::unbounded();
        let rate_limiter = Arc::new(IdentifyRateLimiter::new(config.max_concurrency));

        let cluster = Self {
            config,
            shards: DashMap::new(),
            rate_limiter,
            event_tx,
        };

        (cluster, event_rx)
    }

    /// Start all shards.
    ///
    /// This spawns a task for each shard and begins connecting to Discord.
    /// Shards will connect with proper rate limiting based on `max_concurrency`.
    pub async fn start(&self) -> Result<(), GatewayError> {
        let shard_ids = self.config.shard_range.shard_ids();
        let total_shards = self.config.shard_range.total_shards();

        info!(
            shards = ?shard_ids,
            total = total_shards,
            max_concurrency = self.config.max_concurrency,
            "Starting cluster"
        );

        for shard_id in shard_ids {
            self.spawn_shard(shard_id, total_shards)?;
        }

        Ok(())
    }

    /// Spawn a single shard.
    fn spawn_shard(&self, shard_id: u16, total_shards: u16) -> Result<(), GatewayError> {
        let shard_config = ShardConfig {
            token: self.config.token.clone(),
            intents: self.config.intents,
            gateway_url: self.config.gateway_url.clone(),
            large_threshold: self.config.large_threshold,
            compress: false,
            max_reconnect_attempts: 10,
            reconnect_base_delay_ms: 1000,
            reconnect_max_delay_ms: 60000,
        };

        let shard = Arc::new(Shard::with_rate_limiter(
            shard_id,
            total_shards,
            shard_config,
            self.rate_limiter.clone(),
        ));

        // Create a per-shard event channel that forwards to the cluster
        // channel, tagging each event with its shard ID.
        let (shard_tx, shard_rx) = flume::unbounded::<Event>();
        let cluster_tx = self.event_tx.clone();

        // Spawn the forwarding task; it exits when either channel closes.
        tokio::spawn(async move {
            while let Ok(event) = shard_rx.recv_async().await {
                if cluster_tx.send_async((shard_id, event)).await.is_err() {
                    break;
                }
            }
        });

        // Spawn the shard's event loop.
        let shard_clone = shard.clone();
        let handle = tokio::spawn(async move { shard_clone.run(shard_tx).await });

        self.shards.insert(shard_id, ShardRunner { shard, handle });

        info!(shard_id, "Shard spawned");
        Ok(())
    }

    /// Get the state of a specific shard.
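    ///
    /// Returns `None` if this cluster does not manage the shard. A quick
    /// monitoring sketch, also using `shard_latency` below:
    ///
    /// ```ignore
    /// for id in cluster.shard_ids() {
    ///     println!("shard {}: {:?} ({:?})", id, cluster.shard_state(id), cluster.shard_latency(id));
    /// }
    /// ```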
    pub fn shard_state(&self, shard_id: u16) -> Option<ShardState> {
        self.shards.get(&shard_id).map(|r| r.shard.state())
    }

    /// Get the last measured latency for a specific shard.
    pub fn shard_latency(&self, shard_id: u16) -> Option<std::time::Duration> {
        self.shards.get(&shard_id).and_then(|r| r.shard.latency())
    }

    /// Get all shard IDs managed by this cluster.
    pub fn shard_ids(&self) -> Vec<u16> {
        self.shards.iter().map(|r| *r.key()).collect()
    }

    /// Send a raw payload to a specific shard.
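    ///
    /// The payload should be a complete gateway frame. For example, a
    /// Request Guild Members frame (opcode 8; hypothetical guild ID):
    ///
    /// ```ignore
    /// cluster.send(0, serde_json::json!({
    ///     "op": 8,
    ///     "d": { "guild_id": "1234567890", "query": "", "limit": 0 }
    /// }))?;
    /// ```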
    pub fn send(&self, shard_id: u16, payload: serde_json::Value) -> Result<(), GatewayError> {
        if let Some(runner) = self.shards.get(&shard_id) {
            runner.shard.send_payload(&payload)
        } else {
            Err(GatewayError::Closed {
                code: 0,
                reason: format!("Shard {} not found", shard_id),
            })
        }
    }

    /// Shutdown all shards gracefully.
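    ///
    /// Signals every shard to close, then waits for each shard task to
    /// finish. Usage sketch:
    ///
    /// ```ignore
    /// cluster.shutdown().await;
    /// ```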
    pub async fn shutdown(&self) {
        info!("Shutting down cluster");

        // Request shutdown for all shards.
        for entry in self.shards.iter() {
            entry.shard.shutdown();
        }

        // Wait for all shard tasks to complete. Remove each runner from the
        // map first so no DashMap guard is held across an await point, which
        // could block or deadlock other tasks touching the map.
        let ids: Vec<u16> = self.shards.iter().map(|r| *r.key()).collect();
        for shard_id in ids {
            if let Some((_, runner)) = self.shards.remove(&shard_id) {
                if let Err(e) = runner.handle.await {
                    error!(shard_id, error = %e, "Shard task panicked");
                }
            }
        }

        info!("Cluster shutdown complete");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_range_all() {
        let range = ShardRange::All { total: 10 };
        let ids = range.shard_ids();
        assert_eq!(ids.len(), 10);
        assert_eq!(ids[0], 0);
        assert_eq!(ids[9], 9);
    }

    #[test]
    fn test_shard_range_specific() {
        let range = ShardRange::Specific {
            ids: vec![0, 5, 10],
            total: 20,
        };
        let ids = range.shard_ids();
        assert_eq!(ids, vec![0, 5, 10]);
        assert_eq!(range.total_shards(), 20);
    }

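    // The Range variant had no direct coverage; a small test mirroring the
    // ones above (values chosen arbitrarily).
    #[test]
    fn test_shard_range_range() {
        let range = ShardRange::Range {
            start: 2,
            end: 5,
            total: 10,
        };
        assert_eq!(range.shard_ids(), vec![2, 3, 4]);
        assert_eq!(range.total_shards(), 10);
    }
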
    #[test]
    fn test_cluster_config() {
        let config = ClusterConfig::new(
            "test_token",
            Intents::GUILDS,
            ShardRange::Range {
                start: 0,
                end: 5,
                total: 10,
            },
        )
        .with_max_concurrency(16);

        assert_eq!(config.max_concurrency, 16);
        assert_eq!(config.shard_range.shard_ids().len(), 5);
    }
}