Skip to main content

temporalio_client/
replaceable.rs

1use crate::NamespacedClient;
2use std::{
3    borrow::Cow,
4    sync::{
5        Arc, RwLock,
6        atomic::{AtomicU32, Ordering},
7    },
8};
9
10/// A client wrapper that allows replacing the underlying client at a later point in time.
11/// Clones of this struct have a shared reference to the underlying client, and each clone also
12/// has its own cached clone of the underlying client. Before every service call, a check is made
13/// whether the shared client was replaced, and the cached clone is updated accordingly.
14///
15/// This struct is fully thread-safe, and it works in a lock-free manner except when the client is
16/// being replaced. A read-write lock is used then, with minimal locking time.
17#[derive(Debug)]
18pub struct SharedReplaceableClient<C>
19where
20    C: Clone + Send + Sync,
21{
22    shared_data: Arc<SharedClientData<C>>,
23    pub(crate) cloned_client: C,
24    pub(crate) cloned_generation: u32,
25}
26
/// State shared between all clones of a replaceable client: the current client and a counter
/// identifying which replacement it is.
#[derive(Debug)]
struct SharedClientData<C: Clone + Send + Sync> {
    /// The current client, guarded by a read-write lock.
    client: RwLock<C>,
    /// Counter that is incremented each time `client` is replaced.
    generation: AtomicU32,
}
35
36impl<C> SharedClientData<C>
37where
38    C: Clone + Send + Sync,
39{
40    fn fetch(&self) -> (C, u32) {
41        let lock = self.client.read().unwrap();
42        let client = lock.clone();
43        // Loading generation under lock to ensure the client won't be updated in the meantime.
44        let generation = self.generation.load(Ordering::Acquire);
45        (client, generation)
46    }
47
48    fn fetch_newer_than(&self, current_generation: u32) -> Option<(C, u32)> {
49        // fetch() will do a second atomic load, but it's necessary to avoid a race condition.
50        (current_generation != self.generation.load(Ordering::Acquire)).then(|| self.fetch())
51    }
52
53    fn replace_client(&self, client: C) {
54        let mut lock = self.client.write().unwrap();
55        *lock = client;
56        // Updating generation under lock to guarantee consistency when multiple threads replace the
57        // client at the same time. The client stored last is always the one with latest generation.
58        self.generation.fetch_add(1, Ordering::AcqRel);
59    }
60}
61
62impl<C> SharedReplaceableClient<C>
63where
64    C: Clone + Send + Sync,
65{
66    /// Creates the initial instance of replaceable client with the provided underlying client.
67    /// Use [`clone()`](Self::clone) method to create more instances that share the same underlying client.
68    pub fn new(client: C) -> Self {
69        let cloned_client = client.clone();
70        Self {
71            shared_data: Arc::new(SharedClientData {
72                client: RwLock::new(client),
73                generation: AtomicU32::new(0),
74            }),
75            cloned_client,
76            cloned_generation: 0,
77        }
78    }
79
80    /// Replaces the client for all instances that share this instance's underlying client.
81    pub fn replace_client(&self, new_client: C) {
82        self.shared_data.replace_client(new_client); // cloned_client will be updated on next mutable call
83    }
84
85    /// Returns a clone of the underlying client.
86    #[allow(dead_code)]
87    pub fn inner_clone(&self) -> C {
88        self.inner_cow().into_owned()
89    }
90
91    /// Returns an immutable reference to this instance's cached clone of the underlying client if
92    /// it's up to date, or a fresh clone of the shared client otherwise. Because it's an immutable
93    /// method, it will not update this instance's cached clone. For this reason, prefer to use
94    /// [`inner_mut_refreshed()`](Self::inner_mut_refreshed) when possible.
95    pub fn inner_cow(&self) -> Cow<'_, C> {
96        self.shared_data
97            .fetch_newer_than(self.cloned_generation)
98            .map(|(c, _)| Cow::Owned(c))
99            .unwrap_or_else(|| Cow::Borrowed(&self.cloned_client))
100    }
101
102    /// Returns a mutable reference to this instance's cached clone of the underlying client. If the
103    /// cached clone is not up to date, it's refreshed before the reference is returned. This method
104    /// is called automatically by most other mutable methods, in particular by all service calls,
105    /// so most of the time it doesn't need to be called directly.
106    ///
107    /// While this method allows mutable access to the underlying client, any configuration changes
108    /// will not be shared with other instances, and will be lost if the client gets replaced from
109    /// anywhere. To make configuration changes, use [`replace_client()`](Self::replace_client)
110    /// instead.
111    pub fn inner_mut_refreshed(&mut self) -> &mut C {
112        if let Some((client, generation)) =
113            self.shared_data.fetch_newer_than(self.cloned_generation)
114        {
115            self.cloned_client = client;
116            self.cloned_generation = generation;
117        }
118        &mut self.cloned_client
119    }
120}
121
122impl<C> Clone for SharedReplaceableClient<C>
123where
124    C: Clone + Send + Sync,
125{
126    /// Creates a new instance of replaceable client that shares the underlying client with this
127    /// instance. Replacing a client in either instance will replace it for both instances, and all
128    /// other clones too.
129    fn clone(&self) -> Self {
130        // self's cloned_client could've been modified through a mutable reference,
131        // so for consistent behavior, we need to fetch it from shared_data.
132        let (client, generation) = self.shared_data.fetch();
133        Self {
134            shared_data: self.shared_data.clone(),
135            cloned_client: client,
136            cloned_generation: generation,
137        }
138    }
139}
140
141impl<C> NamespacedClient for SharedReplaceableClient<C>
142where
143    C: NamespacedClient + Clone + Send + Sync,
144{
145    fn namespace(&self) -> String {
146        self.inner_cow().namespace()
147    }
148
149    fn identity(&self) -> String {
150        self.inner_cow().identity()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;

    /// Minimal client whose only state is the identity it reports.
    #[derive(Debug, Clone)]
    struct StubClient {
        identity: String,
    }

    impl StubClient {
        /// Creates a stub that reports `identity`.
        fn new(identity: &str) -> Self {
            let identity = identity.to_owned();
            Self { identity }
        }
    }

    impl NamespacedClient for StubClient {
        fn namespace(&self) -> String {
            String::from("default")
        }

        fn identity(&self) -> String {
            self.identity.clone()
        }
    }

    #[test]
    fn cow_returns_reference_before_and_clone_after_refresh() {
        let mut replaceable = SharedReplaceableClient::new(StubClient::new("1"));

        // Nothing has been replaced yet, so the cached clone is handed out by reference.
        match replaceable.inner_cow() {
            Cow::Borrowed(inner) => assert_eq!(inner.identity, "1"),
            Cow::Owned(_) => panic!("expected borrowed inner"),
        }

        // After a replacement the cached clone is stale, so a fresh clone is returned.
        replaceable.replace_client(StubClient::new("2"));
        match replaceable.inner_cow() {
            Cow::Owned(inner) => assert_eq!(inner.identity, "2"),
            Cow::Borrowed(_) => panic!("expected owned inner"),
        }

        // Refreshing brings the cached clone up to date, so it is borrowed again.
        assert_eq!(replaceable.inner_mut_refreshed().identity, "2");
        match replaceable.inner_cow() {
            Cow::Borrowed(inner) => assert_eq!(inner.identity, "2"),
            Cow::Owned(_) => panic!("expected borrowed inner"),
        }
    }

    #[test]
    fn client_replaced_in_clones() {
        // First family: the replacement is made through the original.
        let first = SharedReplaceableClient::new(StubClient::new("1"));
        let first_clone = first.clone();
        assert_eq!(first.identity(), "1");
        assert_eq!(first_clone.identity(), "1");

        first.replace_client(StubClient::new("2"));
        assert_eq!(first.identity(), "2");
        assert_eq!(first_clone.identity(), "2");

        // Second, unrelated family: the replacement is made through the clone.
        let second = SharedReplaceableClient::new(StubClient::new("3"));
        let second_clone = second.clone();
        assert_eq!(second.identity(), "3");
        assert_eq!(second_clone.identity(), "3");

        second_clone.replace_client(StubClient::new("4"));
        assert_eq!(second.identity(), "4");
        assert_eq!(second_clone.identity(), "4");
        // The first family must not be affected by the second one.
        assert_eq!(first.identity(), "2");
        assert_eq!(first_clone.identity(), "2");
    }

    #[test]
    fn client_replaced_from_multiple_threads() {
        let mut root = SharedReplaceableClient::new(StubClient::new("original"));
        std::thread::scope(|scope| {
            for thread_no in 0..100 {
                let mut worker = root.clone();
                scope.spawn(move || {
                    for i in 0..1000 {
                        // A refresh never moves the cached generation backwards.
                        let generation_before = worker.cloned_generation;
                        worker.inner_mut_refreshed();
                        let generation_refreshed = worker.cloned_generation;
                        assert!(generation_refreshed >= generation_before);

                        // Our own replacement must be observable as a newer generation.
                        let replace_identity = format!("{thread_no}-{i}");
                        worker.replace_client(StubClient::new(&replace_identity));
                        worker.inner_mut_refreshed();
                        assert!(worker.cloned_generation > generation_refreshed);

                        // If the client we now see was stored by this thread, it has to be the
                        // one stored just above.
                        let refreshed_identity = worker.identity();
                        let author = refreshed_identity.split('-').next().unwrap();
                        if author == thread_no.to_string() {
                            assert_eq!(replace_identity, refreshed_identity);
                        }
                    }
                });
            }
        });
        root.inner_mut_refreshed();
        assert_eq!(root.cloned_generation, 100_000);
        assert!(root.identity().ends_with("-999"));
    }
}