temporalio_client/
replaceable.rs1use crate::NamespacedClient;
2use std::{
3 borrow::Cow,
4 sync::{
5 Arc, RwLock,
6 atomic::{AtomicU32, Ordering},
7 },
8};
9
10#[derive(Debug)]
18pub struct SharedReplaceableClient<C>
19where
20 C: Clone + Send + Sync,
21{
22 shared_data: Arc<SharedClientData<C>>,
23 pub(crate) cloned_client: C,
24 pub(crate) cloned_generation: u32,
25}
26
/// State shared between all clones of a [`SharedReplaceableClient`].
#[derive(Debug)]
struct SharedClientData<C: Clone + Send + Sync> {
    /// The current client; only swapped while the write lock is held.
    client: RwLock<C>,
    /// Incremented by one, under the write lock, on every replacement.
    generation: AtomicU32,
}
35
36impl<C> SharedClientData<C>
37where
38 C: Clone + Send + Sync,
39{
40 fn fetch(&self) -> (C, u32) {
41 let lock = self.client.read().unwrap();
42 let client = lock.clone();
43 let generation = self.generation.load(Ordering::Acquire);
45 (client, generation)
46 }
47
48 fn fetch_newer_than(&self, current_generation: u32) -> Option<(C, u32)> {
49 (current_generation != self.generation.load(Ordering::Acquire)).then(|| self.fetch())
51 }
52
53 fn replace_client(&self, client: C) {
54 let mut lock = self.client.write().unwrap();
55 *lock = client;
56 self.generation.fetch_add(1, Ordering::AcqRel);
59 }
60}
61
62impl<C> SharedReplaceableClient<C>
63where
64 C: Clone + Send + Sync,
65{
66 pub fn new(client: C) -> Self {
69 let cloned_client = client.clone();
70 Self {
71 shared_data: Arc::new(SharedClientData {
72 client: RwLock::new(client),
73 generation: AtomicU32::new(0),
74 }),
75 cloned_client,
76 cloned_generation: 0,
77 }
78 }
79
80 pub fn replace_client(&self, new_client: C) {
82 self.shared_data.replace_client(new_client); }
84
85 #[allow(dead_code)]
87 pub fn inner_clone(&self) -> C {
88 self.inner_cow().into_owned()
89 }
90
91 pub fn inner_cow(&self) -> Cow<'_, C> {
96 self.shared_data
97 .fetch_newer_than(self.cloned_generation)
98 .map(|(c, _)| Cow::Owned(c))
99 .unwrap_or_else(|| Cow::Borrowed(&self.cloned_client))
100 }
101
102 pub fn inner_mut_refreshed(&mut self) -> &mut C {
112 if let Some((client, generation)) =
113 self.shared_data.fetch_newer_than(self.cloned_generation)
114 {
115 self.cloned_client = client;
116 self.cloned_generation = generation;
117 }
118 &mut self.cloned_client
119 }
120}
121
122impl<C> Clone for SharedReplaceableClient<C>
123where
124 C: Clone + Send + Sync,
125{
126 fn clone(&self) -> Self {
130 let (client, generation) = self.shared_data.fetch();
133 Self {
134 shared_data: self.shared_data.clone(),
135 cloned_client: client,
136 cloned_generation: generation,
137 }
138 }
139}
140
141impl<C> NamespacedClient for SharedReplaceableClient<C>
142where
143 C: NamespacedClient + Clone + Send + Sync,
144{
145 fn namespace(&self) -> String {
146 self.inner_cow().namespace()
147 }
148
149 fn identity(&self) -> String {
150 self.inner_cow().identity()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;

    /// Minimal `NamespacedClient` whose identity is fixed at construction.
    #[derive(Debug, Clone)]
    struct StubClient {
        identity: String,
    }

    impl StubClient {
        fn new(identity: &str) -> Self {
            let identity = String::from(identity);
            Self { identity }
        }
    }

    impl NamespacedClient for StubClient {
        fn namespace(&self) -> String {
            String::from("default")
        }

        fn identity(&self) -> String {
            self.identity.clone()
        }
    }

    /// Asserts, in order, that every handle reports `expected` as its identity.
    fn assert_identities(expected: &str, handles: &[&SharedReplaceableClient<StubClient>]) {
        for &handle in handles {
            assert_eq!(handle.identity(), expected);
        }
    }

    #[test]
    fn cow_returns_reference_before_and_clone_after_refresh() {
        let mut wrapper = SharedReplaceableClient::new(StubClient::new("1"));

        // Fresh handle: the cache is current, so nothing is cloned.
        match wrapper.inner_cow() {
            Cow::Borrowed(stub) => assert_eq!(stub.identity, "1"),
            Cow::Owned(_) => panic!("expected borrowed inner"),
        }

        // After a replacement the cache is stale, so an owned clone is returned.
        wrapper.replace_client(StubClient::new("2"));
        match wrapper.inner_cow() {
            Cow::Owned(stub) => assert_eq!(stub.identity, "2"),
            Cow::Borrowed(_) => panic!("expected owned inner"),
        }

        // Refreshing updates the cache, so borrowing works again.
        assert_eq!(wrapper.inner_mut_refreshed().identity, "2");
        match wrapper.inner_cow() {
            Cow::Borrowed(stub) => assert_eq!(stub.identity, "2"),
            Cow::Owned(_) => panic!("expected borrowed inner"),
        }
    }

    #[test]
    fn client_replaced_in_clones() {
        let original1 = SharedReplaceableClient::new(StubClient::new("1"));
        let clone1 = original1.clone();
        assert_identities("1", &[&original1, &clone1]);

        // A replacement made through the original is visible in the clone...
        original1.replace_client(StubClient::new("2"));
        assert_identities("2", &[&original1, &clone1]);

        let original2 = SharedReplaceableClient::new(StubClient::new("3"));
        let clone2 = original2.clone();
        assert_identities("3", &[&original2, &clone2]);

        // ...and one made through a clone is visible in the original, while unrelated
        // wrappers are left alone.
        clone2.replace_client(StubClient::new("4"));
        assert_identities("4", &[&original2, &clone2]);
        assert_identities("2", &[&original1, &clone1]);
    }

    #[test]
    fn client_replaced_from_multiple_threads() {
        const THREADS: usize = 100;
        const ITERATIONS: usize = 1000;

        let mut client = SharedReplaceableClient::new(StubClient::new("original"));
        std::thread::scope(|scope| {
            for thread_no in 0..THREADS {
                let mut local = client.clone();
                scope.spawn(move || {
                    for i in 0..ITERATIONS {
                        let before = local.cloned_generation;
                        local.inner_mut_refreshed();
                        let after_refresh = local.cloned_generation;
                        assert!(after_refresh >= before);

                        let written = format!("{thread_no}-{i}");
                        local.replace_client(StubClient::new(&written));
                        local.inner_mut_refreshed();
                        // This thread's own replacement must be visible after a refresh.
                        assert!(local.cloned_generation > after_refresh);

                        // If the observed client came from this thread, it must be the one just
                        // written: this thread's earlier clients have all been superseded.
                        let observed = local.identity();
                        let writer = observed.split('-').next().unwrap();
                        if writer == thread_no.to_string() {
                            assert_eq!(written, observed);
                        }
                    }
                });
            }
        });

        client.inner_mut_refreshed();
        let expected_generation = u32::try_from(THREADS * ITERATIONS).unwrap();
        assert_eq!(client.cloned_generation, expected_generation);
        assert!(client.identity().ends_with(&format!("-{}", ITERATIONS - 1)));
    }
}