squads_temporal_client/
replaceable.rs1use crate::NamespacedClient;
2use std::{
3 borrow::Cow,
4 sync::{
5 Arc, RwLock,
6 atomic::{AtomicU32, Ordering},
7 },
8};
9
10#[derive(Debug)]
18pub struct SharedReplaceableClient<C>
19where
20 C: Clone + Send + Sync,
21{
22 shared_data: Arc<SharedClientData<C>>,
23 cloned_client: C,
24 cloned_generation: u32,
25}
26
/// The slot shared by all clones of a [`SharedReplaceableClient`].
#[derive(Debug)]
struct SharedClientData<C: Clone + Send + Sync> {
    /// The current client.
    client: RwLock<C>,
    /// Number of replacements performed so far; lets a handle detect that its
    /// cached copy is stale with a single atomic load.
    generation: AtomicU32,
}
35
36impl<C> SharedClientData<C>
37where
38 C: Clone + Send + Sync,
39{
40 fn fetch(&self) -> (C, u32) {
41 let lock = self.client.read().unwrap();
42 let client = lock.clone();
43 let generation = self.generation.load(Ordering::Acquire);
45 (client, generation)
46 }
47
48 fn fetch_newer_than(&self, current_generation: u32) -> Option<(C, u32)> {
49 (current_generation != self.generation.load(Ordering::Acquire)).then(|| self.fetch())
51 }
52
53 fn replace_client(&self, client: C) {
54 let mut lock = self.client.write().unwrap();
55 *lock = client;
56 self.generation.fetch_add(1, Ordering::AcqRel);
59 }
60}
61
62impl<C> SharedReplaceableClient<C>
63where
64 C: Clone + Send + Sync,
65{
66 pub fn new(client: C) -> Self {
69 let cloned_client = client.clone();
70 Self {
71 shared_data: Arc::new(SharedClientData {
72 client: RwLock::new(client),
73 generation: AtomicU32::new(0),
74 }),
75 cloned_client,
76 cloned_generation: 0,
77 }
78 }
79
80 pub fn replace_client(&self, new_client: C) {
82 self.shared_data.replace_client(new_client); }
84
85 pub fn inner_clone(&self) -> C {
87 self.inner_cow().into_owned()
88 }
89
90 pub fn inner_cow(&self) -> Cow<'_, C> {
95 self.shared_data
96 .fetch_newer_than(self.cloned_generation)
97 .map(|(c, _)| Cow::Owned(c))
98 .unwrap_or_else(|| Cow::Borrowed(&self.cloned_client))
99 }
100
101 pub fn inner_mut_refreshed(&mut self) -> &mut C {
110 if let Some((client, generation)) =
111 self.shared_data.fetch_newer_than(self.cloned_generation)
112 {
113 self.cloned_client = client;
114 self.cloned_generation = generation;
115 }
116 &mut self.cloned_client
117 }
118}
119
120impl<C> Clone for SharedReplaceableClient<C>
121where
122 C: Clone + Send + Sync,
123{
124 fn clone(&self) -> Self {
128 let (client, generation) = self.shared_data.fetch();
131 Self {
132 shared_data: self.shared_data.clone(),
133 cloned_client: client,
134 cloned_generation: generation,
135 }
136 }
137}
138
139impl<C> NamespacedClient for SharedReplaceableClient<C>
140where
141 C: NamespacedClient + Clone + Send + Sync,
142{
143 fn namespace(&self) -> String {
144 self.inner_cow().namespace()
145 }
146
147 fn identity(&self) -> String {
148 self.inner_cow().identity()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamespacedClient;
    use std::borrow::Cow;

    /// Test double whose identity is fixed at construction time.
    #[derive(Debug, Clone)]
    struct StubClient {
        identity: String,
    }

    impl StubClient {
        fn new(identity: &str) -> Self {
            Self {
                identity: String::from(identity),
            }
        }
    }

    impl NamespacedClient for StubClient {
        fn namespace(&self) -> String {
            "default".into()
        }

        fn identity(&self) -> String {
            self.identity.clone()
        }
    }

    /// Unwraps a `Cow::Borrowed`, panicking if a fresh clone was returned instead.
    fn expect_borrowed<'a>(inner: Cow<'a, StubClient>) -> &'a StubClient {
        match inner {
            Cow::Borrowed(client) => client,
            Cow::Owned(_) => panic!("expected borrowed inner"),
        }
    }

    /// Unwraps a `Cow::Owned`, panicking if the cached copy was borrowed instead.
    fn expect_owned(inner: Cow<'_, StubClient>) -> StubClient {
        match inner {
            Cow::Owned(client) => client,
            Cow::Borrowed(_) => panic!("expected owned inner"),
        }
    }

    #[test]
    fn cow_returns_reference_before_and_clone_after_refresh() {
        let mut client = SharedReplaceableClient::new(StubClient::new("1"));
        assert_eq!(expect_borrowed(client.inner_cow()).identity, "1");

        // After a replacement the cached copy is stale, so a fresh clone is handed out.
        client.replace_client(StubClient::new("2"));
        assert_eq!(expect_owned(client.inner_cow()).identity, "2");

        // Refreshing updates the cache, after which borrowing works again.
        assert_eq!(client.inner_mut_refreshed().identity, "2");
        assert_eq!(expect_borrowed(client.inner_cow()).identity, "2");
    }

    #[test]
    fn client_replaced_in_clones() {
        fn assert_identities(expected: &str, handles: &[&SharedReplaceableClient<StubClient>]) {
            for handle in handles {
                assert_eq!(handle.identity(), expected);
            }
        }

        let original1 = SharedReplaceableClient::new(StubClient::new("1"));
        let clone1 = original1.clone();
        assert_identities("1", &[&original1, &clone1]);

        // Replacing through the original is visible through its clone.
        original1.replace_client(StubClient::new("2"));
        assert_identities("2", &[&original1, &clone1]);

        let original2 = SharedReplaceableClient::new(StubClient::new("3"));
        let clone2 = original2.clone();
        assert_identities("3", &[&original2, &clone2]);

        // Replacing through a clone is visible through the original, and
        // unrelated handle families are unaffected.
        clone2.replace_client(StubClient::new("4"));
        assert_identities("4", &[&original2, &clone2]);
        assert_identities("2", &[&original1, &clone1]);
    }

    #[test]
    fn client_replaced_from_multiple_threads() {
        const THREADS: u32 = 100;
        const REPLACEMENTS_PER_THREAD: u32 = 1000;

        let mut client = SharedReplaceableClient::new(StubClient::new("original"));
        std::thread::scope(|scope| {
            for thread_no in 0..THREADS {
                let mut worker = client.clone();
                scope.spawn(move || {
                    for i in 0..REPLACEMENTS_PER_THREAD {
                        let before_refresh = worker.cloned_generation;
                        worker.inner_mut_refreshed();
                        let after_refresh = worker.cloned_generation;
                        assert!(after_refresh >= before_refresh);

                        let written = format!("{thread_no}-{i}");
                        worker.replace_client(StubClient::new(&written));
                        worker.inner_mut_refreshed();
                        // Our own replacement must have advanced the generation.
                        assert!(worker.cloned_generation > after_refresh);

                        // This thread's writes are sequential, so if the newest
                        // client still came from this thread it must be the one
                        // just written.
                        let observed = worker.identity();
                        let writer = observed.split('-').next().unwrap();
                        if writer == thread_no.to_string() {
                            assert_eq!(written, observed);
                        }
                    }
                });
            }
        });

        client.inner_mut_refreshed();
        assert_eq!(client.cloned_generation, THREADS * REPLACEMENTS_PER_THREAD);
        // Whichever thread wrote last, its final write was index 999.
        assert!(client.identity().ends_with("-999"));
    }
}