1use std::future::Future;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use asupersync::{Cx, Outcome};
10use sqlmodel_core::{Connection, Error};
11
12use crate::{Pool, PooledConnection};
13
/// Policy used to pick which replica pool serves a read.
///
/// The value is small and `Copy`, so it is passed and returned by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReplicaStrategy {
    /// Walk the replica list in order, wrapping back to the first replica
    /// after the last one.
    RoundRobin,
    /// Spread reads across replicas using a scrambled sequence number.
    ///
    /// The choice is deterministic (derived from an internal counter), not
    /// drawn from a random number generator.
    Random,
}
22
23pub struct ReplicaPool<C: Connection> {
41 primary: Pool<C>,
43 replicas: Vec<Pool<C>>,
45 strategy: ReplicaStrategy,
47 round_robin_counter: AtomicUsize,
49}
50
51impl<C: Connection> ReplicaPool<C> {
52 pub fn new(primary: Pool<C>, replicas: Vec<Pool<C>>) -> Self {
54 Self {
55 primary,
56 replicas,
57 strategy: ReplicaStrategy::RoundRobin,
58 round_robin_counter: AtomicUsize::new(0),
59 }
60 }
61
62 pub fn with_strategy(
64 primary: Pool<C>,
65 replicas: Vec<Pool<C>>,
66 strategy: ReplicaStrategy,
67 ) -> Self {
68 Self {
69 primary,
70 replicas,
71 strategy,
72 round_robin_counter: AtomicUsize::new(0),
73 }
74 }
75
76 pub async fn acquire_read<F, Fut>(
81 &self,
82 cx: &Cx,
83 factory: F,
84 ) -> Outcome<PooledConnection<C>, Error>
85 where
86 F: Fn() -> Fut,
87 Fut: Future<Output = Outcome<C, Error>>,
88 {
89 if self.replicas.is_empty() {
90 return self.primary.acquire(cx, factory).await;
91 }
92
93 let idx = self.select_replica();
94 self.replicas[idx].acquire(cx, factory).await
95 }
96
97 pub async fn acquire_write<F, Fut>(
99 &self,
100 cx: &Cx,
101 factory: F,
102 ) -> Outcome<PooledConnection<C>, Error>
103 where
104 F: Fn() -> Fut,
105 Fut: Future<Output = Outcome<C, Error>>,
106 {
107 self.primary.acquire(cx, factory).await
108 }
109
110 pub async fn acquire_primary<F, Fut>(
112 &self,
113 cx: &Cx,
114 factory: F,
115 ) -> Outcome<PooledConnection<C>, Error>
116 where
117 F: Fn() -> Fut,
118 Fut: Future<Output = Outcome<C, Error>>,
119 {
120 self.primary.acquire(cx, factory).await
121 }
122
123 pub fn primary(&self) -> &Pool<C> {
125 &self.primary
126 }
127
128 pub fn replicas(&self) -> &[Pool<C>] {
130 &self.replicas
131 }
132
133 pub fn replica_count(&self) -> usize {
135 self.replicas.len()
136 }
137
138 pub fn strategy(&self) -> ReplicaStrategy {
140 self.strategy
141 }
142
143 fn select_replica(&self) -> usize {
144 match self.strategy {
145 ReplicaStrategy::RoundRobin => {
146 let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
147 idx % self.replicas.len()
148 }
149 ReplicaStrategy::Random => {
150 let seq = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
153 #[allow(clippy::cast_possible_truncation)]
156 let seq32 = seq as u32;
157 let mixed = seq32.wrapping_mul(2_654_435_761_u32);
158 (mixed as usize) % self.replicas.len()
159 }
160 }
161 }
162}
163
164impl<C: Connection> std::fmt::Debug for ReplicaPool<C> {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("ReplicaPool")
167 .field("primary", &"Pool { .. }")
168 .field("replicas", &self.replicas.len())
169 .field("strategy", &self.strategy)
170 .field(
171 "round_robin_counter",
172 &self.round_robin_counter.load(Ordering::Relaxed),
173 )
174 .finish()
175 }
176}