squads_temporal_client/replaceable.rs

1use crate::NamespacedClient;
2use std::{
3    borrow::Cow,
4    sync::{
5        Arc, RwLock,
6        atomic::{AtomicU32, Ordering},
7    },
8};
9
10/// A client wrapper that allows replacing the underlying client at a later point in time.
11/// Clones of this struct have a shared reference to the underlying client, and each clone also
12/// has its own cached clone of the underlying client. Before every service call, a check is made
13/// whether the shared client was replaced, and the cached clone is updated accordingly.
14///
15/// This struct is fully thread-safe, and it works in a lock-free manner except when the client is
16/// being replaced. A read-write lock is used then, with minimal locking time.
17#[derive(Debug)]
18pub struct SharedReplaceableClient<C>
19where
20    C: Clone + Send + Sync,
21{
22    shared_data: Arc<SharedClientData<C>>,
23    cloned_client: C,
24    cloned_generation: u32,
25}
26
/// State shared through an `Arc` by every instance of one replaceable-client group.
#[derive(Debug)]
struct SharedClientData<C: Clone + Send + Sync> {
    /// The current underlying client. It is replaced as a whole under the write lock.
    client: RwLock<C>,
    /// Replacement counter. It is bumped once per replacement, so instances can tell whether
    /// their cached clone is stale.
    generation: AtomicU32,
}
35
36impl<C> SharedClientData<C>
37where
38    C: Clone + Send + Sync,
39{
40    fn fetch(&self) -> (C, u32) {
41        let lock = self.client.read().unwrap();
42        let client = lock.clone();
43        // Loading generation under lock to ensure the client won't be updated in the meantime.
44        let generation = self.generation.load(Ordering::Acquire);
45        (client, generation)
46    }
47
48    fn fetch_newer_than(&self, current_generation: u32) -> Option<(C, u32)> {
49        // fetch() will do a second atomic load, but it's necessary to avoid a race condition.
50        (current_generation != self.generation.load(Ordering::Acquire)).then(|| self.fetch())
51    }
52
53    fn replace_client(&self, client: C) {
54        let mut lock = self.client.write().unwrap();
55        *lock = client;
56        // Updating generation under lock to guarantee consistency when multiple threads replace the
57        // client at the same time. The client stored last is always the one with latest generation.
58        self.generation.fetch_add(1, Ordering::AcqRel);
59    }
60}
61
62impl<C> SharedReplaceableClient<C>
63where
64    C: Clone + Send + Sync,
65{
66    /// Creates the initial instance of replaceable client with the provided underlying client.
67    /// Use [`clone()`](Self::clone) method to create more instances that share the same underlying client.
68    pub fn new(client: C) -> Self {
69        let cloned_client = client.clone();
70        Self {
71            shared_data: Arc::new(SharedClientData {
72                client: RwLock::new(client),
73                generation: AtomicU32::new(0),
74            }),
75            cloned_client,
76            cloned_generation: 0,
77        }
78    }
79
80    /// Replaces the client for all instances that share this instance's underlying client.
81    pub fn replace_client(&self, new_client: C) {
82        self.shared_data.replace_client(new_client); // cloned_client will be updated on next mutable call
83    }
84
85    /// Returns a clone of the underlying client.
86    pub fn inner_clone(&self) -> C {
87        self.inner_cow().into_owned()
88    }
89
90    /// Returns an immutable reference to this instance's cached clone of the underlying client if
91    /// it's up to date, or a fresh clone of the shared client otherwise. Because it's an immutable
92    /// method, it will not update this instance's cached clone. For this reason, prefer to use
93    /// [`inner_mut_refreshed()`](Self::inner_mut_refreshed) when possible.
94    pub fn inner_cow(&self) -> Cow<'_, C> {
95        self.shared_data
96            .fetch_newer_than(self.cloned_generation)
97            .map(|(c, _)| Cow::Owned(c))
98            .unwrap_or_else(|| Cow::Borrowed(&self.cloned_client))
99    }
100
101    /// Returns a mutable reference to this instance's cached clone of the underlying client. If the
102    /// cached clone is not up to date, it's refreshed before the reference is returned. This method
103    /// is called automatically by most other mutable methods, in particular by all service calls,
104    /// so most of the time it doesn't need to be called directly.
105    ///
106    /// While this method allows mutable access to the underlying client, any configuration changes
107    /// will not be shared with other instances, and will be lost if the client gets replaced from
108    /// anywhere. To make configuration changes, use [`replace_client()`](Self::replace_client) instead.
109    pub fn inner_mut_refreshed(&mut self) -> &mut C {
110        if let Some((client, generation)) =
111            self.shared_data.fetch_newer_than(self.cloned_generation)
112        {
113            self.cloned_client = client;
114            self.cloned_generation = generation;
115        }
116        &mut self.cloned_client
117    }
118}
119
120impl<C> Clone for SharedReplaceableClient<C>
121where
122    C: Clone + Send + Sync,
123{
124    /// Creates a new instance of replaceable client that shares the underlying client with this
125    /// instance. Replacing a client in either instance will replace it for both instances, and all
126    /// other clones too.
127    fn clone(&self) -> Self {
128        // self's cloned_client could've been modified through a mutable reference,
129        // so for consistent behavior, we need to fetch it from shared_data.
130        let (client, generation) = self.shared_data.fetch();
131        Self {
132            shared_data: self.shared_data.clone(),
133            cloned_client: client,
134            cloned_generation: generation,
135        }
136    }
137}
138
139impl<C> NamespacedClient for SharedReplaceableClient<C>
140where
141    C: NamespacedClient + Clone + Send + Sync,
142{
143    fn namespace(&self) -> String {
144        self.inner_cow().namespace()
145    }
146
147    fn identity(&self) -> String {
148        self.inner_cow().identity()
149    }
150}
151
#[cfg(test)]
mod tests {
    // `NamespacedClient` and `Cow` come in through the parent module's imports.
    use super::*;

    /// Minimal `NamespacedClient` whose only state is its identity, used to tell clients apart.
    #[derive(Debug, Clone)]
    struct StubClient {
        identity: String,
    }

    impl StubClient {
        fn new(identity: &str) -> Self {
            Self {
                identity: identity.to_owned(),
            }
        }
    }

    impl NamespacedClient for StubClient {
        fn namespace(&self) -> String {
            "default".into()
        }

        fn identity(&self) -> String {
            self.identity.clone()
        }
    }

    /// `inner_cow()` behaves differently depending on whether the cache is current:
    /// - it borrows the cached clone while the clone is current;
    /// - it hands out an owned fresh copy after a replacement;
    /// - it borrows again once `inner_mut_refreshed()` has updated the cache.
    #[test]
    fn cow_returns_reference_before_and_clone_after_refresh() {
        let mut client = SharedReplaceableClient::new(StubClient::new("1"));
        let Cow::Borrowed(inner) = client.inner_cow() else {
            panic!("expected borrowed inner");
        };
        assert_eq!(inner.identity, "1");

        client.replace_client(StubClient::new("2"));
        let Cow::Owned(inner) = client.inner_cow() else {
            panic!("expected owned inner");
        };
        assert_eq!(inner.identity, "2");

        assert_eq!(client.inner_mut_refreshed().identity, "2");
        let Cow::Borrowed(inner) = client.inner_cow() else {
            panic!("expected borrowed inner");
        };
        assert_eq!(inner.identity, "2");
    }

    /// A replacement made through any instance is visible through every instance of the same
    /// sharing group, and never leaks into an unrelated group.
    #[test]
    fn client_replaced_in_clones() {
        let original1 = SharedReplaceableClient::new(StubClient::new("1"));
        let clone1 = original1.clone();
        assert_eq!(original1.identity(), "1");
        assert_eq!(clone1.identity(), "1");

        original1.replace_client(StubClient::new("2"));
        assert_eq!(original1.identity(), "2");
        assert_eq!(clone1.identity(), "2");

        let original2 = SharedReplaceableClient::new(StubClient::new("3"));
        let clone2 = original2.clone();
        assert_eq!(original2.identity(), "3");
        assert_eq!(clone2.identity(), "3");

        clone2.replace_client(StubClient::new("4"));
        assert_eq!(original2.identity(), "4");
        assert_eq!(clone2.identity(), "4");
        assert_eq!(original1.identity(), "2");
        assert_eq!(clone1.identity(), "2");
    }

    /// Many threads concurrently refresh and replace the client through their own clones.
    /// Generations must never go backwards, and no replacement may be lost.
    #[test]
    fn client_replaced_from_multiple_threads() {
        const THREADS: u32 = 100;
        const ITERATIONS: u32 = 1000;

        let mut client = SharedReplaceableClient::new(StubClient::new("original"));
        std::thread::scope(|scope| {
            for thread_no in 0..THREADS {
                let mut client = client.clone();
                scope.spawn(move || {
                    for i in 0..ITERATIONS {
                        let old_generation = client.cloned_generation;
                        client.inner_mut_refreshed();
                        let current_generation = client.cloned_generation;
                        assert!(current_generation >= old_generation);
                        let replace_identity = format!("{thread_no}-{i}");
                        client.replace_client(StubClient::new(&replace_identity));
                        client.inner_mut_refreshed();
                        // Our own replacement bumped the generation, so the refresh must move
                        // forward.
                        assert!(client.cloned_generation > current_generation);
                        let refreshed_identity = client.identity();
                        // If the latest client came from this thread, it must be the one just
                        // written: all of this thread's earlier writes happened before it.
                        if refreshed_identity.split('-').next().unwrap() == thread_no.to_string() {
                            assert_eq!(replace_identity, refreshed_identity);
                        }
                    }
                });
            }
        });
        client.inner_mut_refreshed();
        // Every replacement bumps the generation exactly once.
        assert_eq!(client.cloned_generation, THREADS * ITERATIONS);
        // The globally last replacement is necessarily some thread's final iteration.
        let last_suffix = format!("-{}", ITERATIONS - 1);
        assert!(client.identity().ends_with(&last_suffix));
    }
}