sqlmodel_pool/sharding.rs
//! Horizontal sharding support for SQLModel Rust.
//!
//! This module provides infrastructure for partitioning data across multiple
//! database shards based on a shard key.
//!
//! # Overview
//!
//! Horizontal sharding distributes rows across multiple databases based on a
//! shard key (e.g., `user_id`, `tenant_id`). This enables:
//!
//! - Horizontal scalability beyond single-database limits
//! - Data isolation between tenants/regions
//! - Improved query performance through data locality
//!
//! # Example
//!
//! ```rust,ignore
//! use sqlmodel_pool::{Pool, PoolConfig, ShardedPool, ShardChooser};
//! use sqlmodel_core::{Model, Value};
//!
//! // Define a shard chooser based on modulo hashing
//! struct ModuloShardChooser {
//!     shard_count: usize,
//! }
//!
//! impl ShardChooser for ModuloShardChooser {
//!     fn choose_for_model(&self, shard_key: &Value) -> String {
//!         let id = match shard_key {
//!             Value::BigInt(n) => *n as usize,
//!             Value::Int(n) => *n as usize,
//!             _ => 0,
//!         };
//!         format!("shard_{}", id % self.shard_count)
//!     }
//!
//!     fn choose_for_query(&self, _hints: &QueryHints) -> Vec<String> {
//!         // Query all shards by default
//!         (0..self.shard_count)
//!             .map(|i| format!("shard_{}", i))
//!             .collect()
//!     }
//! }
//!
//! // Create sharded pool
//! let mut sharded_pool = ShardedPool::new(ModuloShardChooser { shard_count: 3 });
//! sharded_pool.add_shard("shard_0", pool_0);
//! sharded_pool.add_shard("shard_1", pool_1);
//! sharded_pool.add_shard("shard_2", pool_2);
//!
//! // Resolve the shard for a model from its shard key. This returns a
//! // `Result`: it fails when the model declares no shard key.
//! let order = Order { user_id: 42, ... };
//! let shard = sharded_pool.choose_for_model(&order)?;
//! ```
54
55use std::collections::HashMap;
56use std::future::Future;
57use std::sync::Arc;
58
59use asupersync::{Cx, Outcome};
60use sqlmodel_core::error::{PoolError, PoolErrorKind};
61use sqlmodel_core::{Connection, Error, Model, Value};
62
63use crate::{Pool, PoolConfig, PooledConnection};
64
65/// Hints for query routing when a specific shard key isn't available.
66///
67/// When executing queries that don't have a clear shard key (e.g., range queries,
68/// aggregations), these hints help the `ShardChooser` decide which shards to query.
69#[derive(Debug, Clone, Default)]
70pub struct QueryHints {
71 /// Specific shard names to target (if known).
72 pub target_shards: Option<Vec<String>>,
73
74 /// Whether to query all shards (scatter-gather).
75 pub scatter_gather: bool,
76
77 /// Optional shard key value extracted from query predicates.
78 pub shard_key_value: Option<Value>,
79
80 /// Query type hint (e.g., "select", "aggregate", "count").
81 pub query_type: Option<String>,
82}
83
84impl QueryHints {
85 /// Create empty hints (defaults to scatter-gather).
86 #[must_use]
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 /// Target specific shards by name.
92 #[must_use]
93 pub fn target(mut self, shards: Vec<String>) -> Self {
94 self.target_shards = Some(shards);
95 self
96 }
97
98 /// Enable scatter-gather mode (query all shards).
99 #[must_use]
100 pub fn scatter_gather(mut self) -> Self {
101 self.scatter_gather = true;
102 self
103 }
104
105 /// Provide a shard key value for routing.
106 #[must_use]
107 pub fn with_shard_key(mut self, value: Value) -> Self {
108 self.shard_key_value = Some(value);
109 self
110 }
111
112 /// Set the query type hint.
113 #[must_use]
114 pub fn query_type(mut self, query_type: impl Into<String>) -> Self {
115 self.query_type = Some(query_type.into());
116 self
117 }
118}
119
120/// Trait for determining which shard(s) to use for operations.
121///
122/// Implement this trait to define your sharding strategy. Common strategies:
123///
124/// - **Modulo hashing**: `shard_key % shard_count`
125/// - **Range-based**: Partition by key ranges (e.g., user IDs 0-1M → shard_0)
126/// - **Consistent hashing**: Minimize rebalancing when adding/removing shards
127/// - **Tenant-based**: Map tenant IDs directly to shard names
128///
129/// # Example
130///
131/// ```rust,ignore
132/// struct TenantShardChooser {
133/// tenant_to_shard: HashMap<String, String>,
134/// default_shard: String,
135/// }
136///
137/// impl ShardChooser for TenantShardChooser {
138/// fn choose_for_model(&self, shard_key: &Value) -> String {
139/// if let Value::Text(tenant_id) = shard_key {
140/// self.tenant_to_shard
141/// .get(tenant_id)
142/// .cloned()
143/// .unwrap_or_else(|| self.default_shard.clone())
144/// } else {
145/// self.default_shard.clone()
146/// }
147/// }
148///
149/// fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
150/// if let Some(Value::Text(tenant_id)) = &hints.shard_key_value {
151/// vec![self.choose_for_model(&Value::Text(tenant_id.clone()))]
152/// } else {
153/// // Query all shards
154/// self.tenant_to_shard.values().cloned().collect()
155/// }
156/// }
157/// }
158/// ```
159pub trait ShardChooser: Send + Sync {
160 /// Choose the shard for a model based on its shard key value.
161 ///
162 /// This is used for INSERT, UPDATE, and DELETE operations where the
163 /// shard key is known from the model instance.
164 ///
165 /// # Arguments
166 ///
167 /// * `shard_key` - The value of the model's shard key field
168 ///
169 /// # Returns
170 ///
171 /// The name of the shard to use (must match a shard registered in `ShardedPool`).
172 fn choose_for_model(&self, shard_key: &Value) -> String;
173
174 /// Choose which shards to query based on query hints.
175 ///
176 /// For queries where the shard key isn't directly available (e.g., range
177 /// queries, joins, aggregations), this method returns the list of shards
178 /// to query.
179 ///
180 /// # Arguments
181 ///
182 /// * `hints` - Query routing hints (target shards, shard key value, etc.)
183 ///
184 /// # Returns
185 ///
186 /// List of shard names to query. For point queries with a known shard key,
187 /// this should return a single shard. For scatter-gather, return all shards.
188 fn choose_for_query(&self, hints: &QueryHints) -> Vec<String>;
189
190 /// Get all registered shard names.
191 ///
192 /// Default implementation returns an empty vec; override if your chooser
193 /// tracks shard names internally.
194 fn all_shards(&self) -> Vec<String> {
195 vec![]
196 }
197}
198
199/// A simple modulo-based shard chooser for numeric shard keys.
200///
201/// Routes based on `shard_key % shard_count`, producing shard names like
202/// `shard_0`, `shard_1`, etc.
203///
204/// This is suitable for evenly distributed numeric keys (e.g., auto-increment IDs).
205/// Not suitable for sequential inserts (hotspotting on latest shard) or
206/// non-numeric keys.
207#[derive(Debug, Clone)]
208pub struct ModuloShardChooser {
209 shard_count: usize,
210 shard_prefix: String,
211}
212
213impl ModuloShardChooser {
214 /// Create a new modulo shard chooser with the given number of shards.
215 ///
216 /// Shards are named `shard_0`, `shard_1`, ..., `shard_{n-1}`.
217 ///
218 /// # Panics
219 ///
220 /// Panics if `shard_count` is 0, as this would cause division by zero
221 /// when routing to shards.
222 #[must_use]
223 pub fn new(shard_count: usize) -> Self {
224 assert!(shard_count > 0, "shard_count must be greater than 0");
225 Self {
226 shard_count,
227 shard_prefix: "shard_".to_string(),
228 }
229 }
230
231 /// Set a custom prefix for shard names (default: "shard_").
232 #[must_use]
233 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
234 self.shard_prefix = prefix.into();
235 self
236 }
237
238 /// Get the shard count.
239 #[must_use]
240 pub fn shard_count(&self) -> usize {
241 self.shard_count
242 }
243
244 /// Extract a numeric value from a Value for modulo calculation.
245 ///
246 /// Truncation on 32-bit platforms is acceptable here since we only need
247 /// the value for consistent shard routing via modulo.
248 #[allow(clippy::cast_possible_truncation)]
249 fn extract_numeric(&self, value: &Value) -> usize {
250 match value {
251 Value::BigInt(n) => (*n).unsigned_abs() as usize,
252 Value::Int(n) => (*n).unsigned_abs() as usize,
253 Value::SmallInt(n) => (*n).unsigned_abs() as usize,
254 Value::Text(s) => {
255 // Hash the string for non-numeric keys
256 use std::hash::{Hash, Hasher};
257 let mut hasher = std::collections::hash_map::DefaultHasher::new();
258 s.hash(&mut hasher);
259 hasher.finish() as usize
260 }
261 _ => 0,
262 }
263 }
264}
265
266impl ShardChooser for ModuloShardChooser {
267 fn choose_for_model(&self, shard_key: &Value) -> String {
268 let n = self.extract_numeric(shard_key);
269 format!("{}{}", self.shard_prefix, n % self.shard_count)
270 }
271
272 fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
273 // If specific shards are targeted, use those
274 if let Some(ref targets) = hints.target_shards {
275 return targets.clone();
276 }
277
278 // If shard key is available, route to specific shard
279 if let Some(ref value) = hints.shard_key_value {
280 return vec![self.choose_for_model(value)];
281 }
282
283 // Default: scatter-gather to all shards
284 self.all_shards()
285 }
286
287 fn all_shards(&self) -> Vec<String> {
288 (0..self.shard_count)
289 .map(|i| format!("{}{}", self.shard_prefix, i))
290 .collect()
291 }
292}
293
294/// A sharded connection pool that routes operations to the correct shard.
295///
296/// `ShardedPool` wraps multiple `Pool` instances, one per shard, and uses
297/// a `ShardChooser` to determine which shard to use for each operation.
298///
299/// # Example
300///
301/// ```rust,ignore
302/// // Create pools for each shard
303/// let pool_0 = Pool::new(PoolConfig::new(10));
304/// let pool_1 = Pool::new(PoolConfig::new(10));
305///
306/// // Create sharded pool with modulo chooser
307/// let chooser = ModuloShardChooser::new(2);
308/// let mut sharded = ShardedPool::new(chooser);
309/// sharded.add_shard("shard_0", pool_0);
310/// sharded.add_shard("shard_1", pool_1);
311///
312/// // Acquire connection from specific shard
313/// let conn = sharded.acquire_for_model(&cx, &order, factory).await?;
314/// ```
315pub struct ShardedPool<C: Connection, S: ShardChooser> {
316 shards: HashMap<String, Pool<C>>,
317 chooser: Arc<S>,
318}
319
320impl<C: Connection, S: ShardChooser> ShardedPool<C, S> {
321 /// Create a new sharded pool with the given shard chooser.
322 pub fn new(chooser: S) -> Self {
323 Self {
324 shards: HashMap::new(),
325 chooser: Arc::new(chooser),
326 }
327 }
328
329 /// Add a shard to the pool.
330 ///
331 /// # Arguments
332 ///
333 /// * `name` - The shard name (must match names returned by the chooser)
334 /// * `pool` - The connection pool for this shard
335 pub fn add_shard(&mut self, name: impl Into<String>, pool: Pool<C>) {
336 self.shards.insert(name.into(), pool);
337 }
338
339 /// Add a shard with a new pool created from the given config.
340 pub fn add_shard_with_config(&mut self, name: impl Into<String>, config: PoolConfig) {
341 self.shards.insert(name.into(), Pool::new(config));
342 }
343
344 /// Get a reference to the shard chooser.
345 pub fn chooser(&self) -> &S {
346 &self.chooser
347 }
348
349 /// Get a reference to a specific shard's pool.
350 pub fn get_shard(&self, name: &str) -> Option<&Pool<C>> {
351 self.shards.get(name)
352 }
353
354 /// Get all shard names.
355 pub fn shard_names(&self) -> Vec<String> {
356 self.shards.keys().cloned().collect()
357 }
358
359 /// Get the number of shards.
360 pub fn shard_count(&self) -> usize {
361 self.shards.len()
362 }
363
364 /// Check if a shard exists.
365 pub fn has_shard(&self, name: &str) -> bool {
366 self.shards.contains_key(name)
367 }
368
369 /// Choose the shard for a model based on its shard key.
370 ///
371 /// Returns the shard name. Use this when you need to know the shard
372 /// without acquiring a connection.
373 #[allow(clippy::result_large_err)]
374 pub fn choose_for_model<M: Model>(&self, model: &M) -> Result<String, Error> {
375 let shard_key = model.shard_key_value().ok_or_else(|| {
376 Error::Pool(PoolError {
377 kind: PoolErrorKind::Config,
378 message: format!(
379 "Model {} has no shard key defined; add #[sqlmodel(shard_key = \"field\")]",
380 M::TABLE_NAME
381 ),
382 source: None,
383 })
384 })?;
385 Ok(self.chooser.choose_for_model(&shard_key))
386 }
387
388 /// Choose shards for a query based on hints.
389 pub fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
390 self.chooser.choose_for_query(hints)
391 }
392
393 /// Acquire a connection from the shard determined by the model's shard key.
394 ///
395 /// # Arguments
396 ///
397 /// * `cx` - The async context
398 /// * `model` - The model instance (must have a shard key)
399 /// * `factory` - Connection factory function
400 ///
401 /// # Errors
402 ///
403 /// Returns an error if:
404 /// - The model has no shard key
405 /// - The determined shard doesn't exist
406 /// - Connection acquisition fails
407 pub async fn acquire_for_model<M, F, Fut>(
408 &self,
409 cx: &Cx,
410 model: &M,
411 factory: F,
412 ) -> Outcome<PooledConnection<C>, Error>
413 where
414 M: Model,
415 F: Fn() -> Fut,
416 Fut: Future<Output = Outcome<C, Error>>,
417 {
418 let shard_name = match self.choose_for_model(model) {
419 Ok(name) => name,
420 Err(e) => return Outcome::Err(e),
421 };
422
423 self.acquire_from_shard(cx, &shard_name, factory).await
424 }
425
426 /// Acquire a connection from a specific shard by name.
427 ///
428 /// # Arguments
429 ///
430 /// * `cx` - The async context
431 /// * `shard_name` - The name of the shard to acquire from
432 /// * `factory` - Connection factory function
433 ///
434 /// # Errors
435 ///
436 /// Returns an error if:
437 /// - The shard doesn't exist
438 /// - Connection acquisition fails
439 pub async fn acquire_from_shard<F, Fut>(
440 &self,
441 cx: &Cx,
442 shard_name: &str,
443 factory: F,
444 ) -> Outcome<PooledConnection<C>, Error>
445 where
446 F: Fn() -> Fut,
447 Fut: Future<Output = Outcome<C, Error>>,
448 {
449 let Some(pool) = self.shards.get(shard_name) else {
450 return Outcome::Err(Error::Pool(PoolError {
451 kind: PoolErrorKind::Config,
452 message: format!(
453 "shard '{}' not found; available shards: {:?}",
454 shard_name,
455 self.shard_names()
456 ),
457 source: None,
458 }));
459 };
460
461 pool.acquire(cx, factory).await
462 }
463
464 /// Acquire connections from multiple shards for scatter-gather queries.
465 ///
466 /// Returns a map of shard name to pooled connection for each successfully
467 /// acquired connection. Failed acquisitions are logged but don't fail the
468 /// entire operation.
469 ///
470 /// # Arguments
471 ///
472 /// * `cx` - The async context
473 /// * `hints` - Query routing hints
474 /// * `factory` - Connection factory function
475 pub async fn acquire_for_query<F, Fut>(
476 &self,
477 cx: &Cx,
478 hints: &QueryHints,
479 factory: F,
480 ) -> Result<HashMap<String, PooledConnection<C>>, Error>
481 where
482 F: Fn() -> Fut + Clone,
483 Fut: Future<Output = Outcome<C, Error>>,
484 {
485 let target_shards = self.choose_for_query(hints);
486 let mut connections = HashMap::new();
487
488 for shard_name in target_shards {
489 match self
490 .acquire_from_shard(cx, &shard_name, factory.clone())
491 .await
492 {
493 Outcome::Ok(conn) => {
494 connections.insert(shard_name, conn);
495 }
496 Outcome::Err(e) => {
497 tracing::warn!(shard = %shard_name, error = %e, "Failed to acquire connection from shard");
498 }
499 Outcome::Cancelled(reason) => {
500 tracing::debug!(shard = %shard_name, reason = ?reason, "Cancelled while acquiring from shard");
501 }
502 Outcome::Panicked(info) => {
503 tracing::error!(shard = %shard_name, panic = ?info, "Panic while acquiring from shard");
504 }
505 }
506 }
507
508 if connections.is_empty() {
509 return Err(Error::Pool(PoolError {
510 kind: PoolErrorKind::Exhausted,
511 message: "failed to acquire connection from any shard".to_string(),
512 source: None,
513 }));
514 }
515
516 Ok(connections)
517 }
518
519 /// Close all shards.
520 pub fn close(&self) {
521 for pool in self.shards.values() {
522 pool.close();
523 }
524 }
525
526 /// Check if all shards are closed.
527 pub fn is_closed(&self) -> bool {
528 self.shards.values().all(|p| p.is_closed())
529 }
530
531 /// Get aggregate statistics across all shards.
532 pub fn stats(&self) -> ShardedPoolStats {
533 let mut total = ShardedPoolStats::default();
534
535 for (name, pool) in &self.shards {
536 let shard_stats = pool.stats();
537 total.per_shard.insert(name.clone(), shard_stats.clone());
538 total.total_connections += shard_stats.total_connections;
539 total.idle_connections += shard_stats.idle_connections;
540 total.active_connections += shard_stats.active_connections;
541 total.pending_requests += shard_stats.pending_requests;
542 total.connections_created += shard_stats.connections_created;
543 total.connections_closed += shard_stats.connections_closed;
544 total.acquires += shard_stats.acquires;
545 total.timeouts += shard_stats.timeouts;
546 }
547
548 total.shard_count = self.shards.len();
549 total
550 }
551}
552
553/// Aggregate statistics for a sharded pool.
554#[derive(Debug, Clone, Default)]
555pub struct ShardedPoolStats {
556 /// Number of shards.
557 pub shard_count: usize,
558 /// Per-shard statistics.
559 pub per_shard: HashMap<String, crate::PoolStats>,
560 /// Total connections across all shards.
561 pub total_connections: usize,
562 /// Idle connections across all shards.
563 pub idle_connections: usize,
564 /// Active connections across all shards.
565 pub active_connections: usize,
566 /// Pending requests across all shards.
567 pub pending_requests: usize,
568 /// Total connections created across all shards.
569 pub connections_created: u64,
570 /// Total connections closed across all shards.
571 pub connections_closed: u64,
572 /// Total acquires across all shards.
573 pub acquires: u64,
574 /// Total timeouts across all shards.
575 pub timeouts: u64,
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Value::Text` shard key from a string slice.
    fn text_key(s: &str) -> Value {
        Value::Text(s.to_string())
    }

    /// The builder methods each fill in their own field.
    #[test]
    fn test_query_hints_builder() {
        let expected_targets = vec!["shard_0".to_string()];
        let hints = QueryHints::new()
            .target(expected_targets.clone())
            .with_shard_key(Value::BigInt(42))
            .query_type("select");

        assert_eq!(
            hints.target_shards.as_deref(),
            Some(expected_targets.as_slice())
        );
        assert_eq!(hints.shard_key_value, Some(Value::BigInt(42)));
        assert_eq!(hints.query_type.as_deref(), Some("select"));
    }

    /// `scatter_gather()` switches the flag on.
    #[test]
    fn test_query_hints_scatter_gather() {
        let hints = QueryHints::new().scatter_gather();
        assert!(hints.scatter_gather);
    }

    /// The constructor stores the shard count.
    #[test]
    fn test_modulo_shard_chooser_new() {
        assert_eq!(ModuloShardChooser::new(4).shard_count(), 4);
    }

    /// A custom prefix replaces the default `shard_` prefix.
    #[test]
    fn test_modulo_shard_chooser_with_prefix() {
        let chooser = ModuloShardChooser::new(3).with_prefix("db_");
        for (key, expected) in [(0, "db_0"), (1, "db_1")] {
            assert_eq!(
                chooser.choose_for_model(&Value::BigInt(key)),
                expected.to_string()
            );
        }
    }

    /// Big integer keys are routed by `key % shard_count`.
    #[test]
    fn test_modulo_shard_chooser_choose_for_model() {
        let chooser = ModuloShardChooser::new(3);
        let cases = [
            (0, "shard_0"),
            (1, "shard_1"),
            (2, "shard_2"),
            (3, "shard_0"),
            (100, "shard_1"),
        ];

        for (key, expected) in cases {
            assert_eq!(chooser.choose_for_model(&Value::BigInt(key)), expected);
        }
    }

    /// The narrower integer variants are routed as well.
    #[test]
    fn test_modulo_shard_chooser_int_types() {
        let chooser = ModuloShardChooser::new(2);

        let from_int = chooser.choose_for_model(&Value::Int(5));
        let from_small = chooser.choose_for_model(&Value::SmallInt(4));

        assert_eq!(from_int, "shard_1");
        assert_eq!(from_small, "shard_0");
    }

    /// Negative keys are routed by their absolute value.
    #[test]
    fn test_modulo_shard_chooser_negative_values() {
        let chooser = ModuloShardChooser::new(3);

        for (key, expected) in [(-1, "shard_1"), (-3, "shard_0")] {
            assert_eq!(chooser.choose_for_model(&Value::BigInt(key)), expected);
        }
    }

    /// Hashing a text key gives the same shard every time.
    #[test]
    fn test_modulo_shard_chooser_string_hash() {
        let chooser = ModuloShardChooser::new(3);

        let first = chooser.choose_for_model(&text_key("user_abc"));
        let second = chooser.choose_for_model(&text_key("user_abc"));
        assert_eq!(first, second);

        // A different string may or may not share that shard; it only has to route.
        let _ = chooser.choose_for_model(&text_key("user_xyz"));
    }

    /// `all_shards` names each shard exactly once.
    #[test]
    fn test_modulo_shard_chooser_all_shards() {
        let all = ModuloShardChooser::new(3).all_shards();

        assert_eq!(all.len(), 3);
        for expected in ["shard_0", "shard_1", "shard_2"] {
            assert!(all.contains(&expected.to_string()));
        }
    }

    /// A shard key hint narrows the query to one shard.
    #[test]
    fn test_modulo_shard_chooser_choose_for_query_with_key() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().with_shard_key(Value::BigInt(5));

        // 5 % 3 = 2
        assert_eq!(chooser.choose_for_query(&hints), ["shard_2"]);
    }

    /// Scatter-gather hints fan out to every shard.
    #[test]
    fn test_modulo_shard_chooser_choose_for_query_scatter() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().scatter_gather();

        assert_eq!(chooser.choose_for_query(&hints).len(), 3);
    }

    /// Explicit targets are returned unchanged.
    #[test]
    fn test_modulo_shard_chooser_choose_for_query_target() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().target(vec!["shard_1".to_string()]);

        assert_eq!(chooser.choose_for_query(&hints), ["shard_1"]);
    }

    /// Default statistics are empty.
    #[test]
    fn test_sharded_pool_stats_default() {
        let stats = ShardedPoolStats::default();

        assert!(stats.per_shard.is_empty());
        assert_eq!(stats.shard_count, 0);
        assert_eq!(stats.total_connections, 0);
    }

    /// Zero shards are rejected at construction time.
    #[test]
    #[should_panic(expected = "shard_count must be greater than 0")]
    fn test_modulo_shard_chooser_zero_shards_panics() {
        let _ = ModuloShardChooser::new(0);
    }
}