Skip to main content

secret_manager/
rotator.rs

1use crate::encryptor::{Encrypted, KeyEncryptor};
2use async_trait::async_trait;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5use tokio_util::sync::CancellationToken;
6use tracing::{error, info};
7
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

/// Pause applied after a failed backend or encryptor call before the rotation loop starts its
/// next iteration, so persistent failures are retried at a bounded rate.
const ERROR_RETRY_DELAY: Duration = Duration::from_millis(30_000);
14
15// ---------------------------------------------------------------------------
16// SecretRotationBackend — write-side trait
17// ---------------------------------------------------------------------------
18
19/// Write-side storage contract required by [`KeyRotator`].
20///
21/// Implement this trait (together with [`SecretBackend`](crate::SecretBackend) if you also need
22/// reading) to bring your own backend.  The two methods together form an optimistic-locking
23/// protocol: read the latest version, then attempt a conditional insert.
24#[async_trait]
25pub trait SecretRotationBackend: Send + Sync + 'static {
26    /// The error type returned on backend failures.
27    type Error: std::error::Error + Send + Sync + 'static;
28
29    /// Returns `(version, activated_at)` of the most recently **inserted** key for `group_id`,
30    /// or `None` when no key exists yet.
31    async fn latest_key_info(
32        &self,
33        group_id: &str,
34    ) -> Result<Option<(u8, SystemTime)>, Self::Error>;
35
36    /// Atomically inserts a new key only when the current version still equals
37    /// `expected_version` (use `None` when no key exists yet).
38    ///
39    /// Returns `true` if the key was inserted, `false` if another instance raced ahead and
40    /// the version no longer matches.  Implementations should acquire an advisory lock or use
41    /// a compare-and-swap so that concurrent rotators converge safely.
42    async fn try_insert_key(
43        &self,
44        group_id: &str,
45        expected_version: Option<u8>,
46        new_version: u8,
47        encrypted: &Encrypted,
48        activated_at: SystemTime,
49    ) -> Result<bool, Self::Error>;
50}
51
// ---------------------------------------------------------------------------
// KeyRotator
// ---------------------------------------------------------------------------

/// Background task that periodically generates and persists a new encryption key.
///
/// `KeyRotator` is the **write side** of the key-management system.  It runs a single
/// perpetual loop: sleep until the current key is due for rotation, generate a new key,
/// encrypt it, and attempt a conditional insert via [`SecretRotationBackend::try_insert_key`].
/// If another instance raced ahead, the insert is skipped and the loop simply sleeps until the
/// *new* key expires.
///
/// Multiple `KeyRotator` instances for the same `group_id` can run concurrently (e.g. for
/// high availability); the optimistic-locking protocol in `try_insert_key` ensures only one
/// insert succeeds per rotation cycle.
///
/// # Type parameters
///
/// The struct itself places no bounds on `B` and `E`.  The constructor and
/// [`run`](Self::run) require the traits listed below.
///
/// - `B` — backend that implements [`SecretRotationBackend`]
/// - `E` — encryptor that implements [`KeyEncryptor`] and `Clone`
/// - `V` — ring buffer size (number of key slots, **must be in `1..=256`**, default 256).
///   Versions cycle through `0..V`.  Must match the `V` of any
///   [`InMemorySecretGroup`](crate::InMemorySecretGroup) consuming the keys.
/// - `S` — key size in bytes (default 32)
///
/// # Standalone use
///
/// `KeyRotator` can be used without a [`SecretSyncer`](crate::SecretSyncer) or
/// [`SecretManager`](crate::SecretManager).  This is useful when you want a dedicated
/// rotation service that writes to shared storage while other nodes only read:
///
/// ```rust,no_run
/// # use secret_manager::*;
/// # use async_trait::async_trait;
/// # use std::time::{Duration, SystemTime};
/// # use tokio_util::sync::CancellationToken;
/// # struct MyBackend;
/// # #[async_trait]
/// # impl SecretRotationBackend for MyBackend {
/// #     type Error = std::convert::Infallible;
/// #     async fn latest_key_info(&self, _: &str) -> Result<Option<(u8, SystemTime)>, Self::Error> { Ok(None) }
/// #     async fn try_insert_key(&self, _: &str, _: Option<u8>, _: u8, _: &Encrypted, _: SystemTime) -> Result<bool, Self::Error> { Ok(true) }
/// # }
/// # async fn example() {
/// # let (backend, encryptor) = (MyBackend, NoOpEncryptor);
/// let rotator: KeyRotator<_, _, 256, 32> = KeyRotator::new(
///     "session-tokens",
///     backend,
///     Duration::from_secs(3600),
///     Duration::from_secs(30),
///     encryptor,
///     || [0u8; 32],
/// );
/// rotator.run(CancellationToken::new()).await;
/// # }
/// ```
pub struct KeyRotator<B, E, const V: usize = 256, const S: usize = 32> {
    /// Logical key group this rotator writes to.
    group_id: String,
    /// Storage the rotated keys are written to.
    backend: B,
    /// Wraps key material before it is handed to `backend`.
    encryptor: E,
    /// Time after a key's `activated_at` at which the next key is generated.
    rotation_interval: Duration,
    /// Added to `SystemTime::now()` to obtain a new key's `activated_at`.
    propagation_delay: Duration,
    /// Source of fresh key material.
    generate_key: Arc<dyn Fn() -> [u8; S] + Send + Sync + 'static>,
}
116
117impl<B: SecretRotationBackend, E: KeyEncryptor + Clone, const V: usize, const S: usize> KeyRotator<B, E, V, S> {
118    /// Create a new `KeyRotator`.
119    ///
120    /// # Arguments
121    ///
122    /// - `group_id` — identifies the logical key group in storage
123    /// - `backend` — implements [`SecretRotationBackend`]
124    /// - `rotation_interval` — how long a key is valid before a new one is generated
125    /// - `propagation_delay` — added to `SystemTime::now()` to compute `activated_at` for the
126    ///   new key, giving syncers time to pull the key before it becomes active
127    /// - `encryptor` — wraps key bytes before storage
128    /// - `generate_key` — produces fresh key material; defaults in [`SecretManager`](crate::SecretManager)
129    ///   to a CSPRNG fill
130    ///
131    /// # Panics
132    ///
133    /// Panics at compile time if `V > 256` (versions are stored as `u8`).
134    pub fn new(
135        group_id: impl Into<String>,
136        backend: B,
137        rotation_interval: Duration,
138        propagation_delay: Duration,
139        encryptor: E,
140        generate_key: impl Fn() -> [u8; S] + Send + Sync + 'static,
141    ) -> Self {
142        const { assert!(V <= 256, "ring buffer size V must be ≤ 256; versions are u8") };
143        Self {
144            group_id: group_id.into(),
145            backend,
146            encryptor,
147            rotation_interval,
148            propagation_delay,
149            generate_key: Arc::new(generate_key),
150        }
151    }
152
153    /// Run the rotation loop until `token` is cancelled.
154    ///
155    /// This method consumes `self` and runs forever, sleeping between rotations.  Pass the
156    /// returned future to [`tokio::spawn`] or run it directly.  Cancel `token` for a clean
157    /// shutdown; the loop exits after the current sleep or retry delay completes.
158    ///
159    /// On backend or encryption errors the rotator backs off for 30 seconds before retrying,
160    /// so transient failures do not cause a tight error loop.
161    pub async fn run(self, token: CancellationToken) {
162        info!(group_id = %self.group_id, "KeyRotator starting");
163
164        loop {
165            let pre_info = match self.backend.latest_key_info(&self.group_id).await {
166                Ok(info) => info,
167                Err(e) => {
168                    error!(group_id = %self.group_id, error = %e, "KeyRotator: backend error");
169                    if sleep_or_cancel(ERROR_RETRY_DELAY, &token).await {
170                        break;
171                    }
172                    continue;
173                }
174            };
175
176            let sleep_dur = match pre_info {
177                Some((_, last_activated_at)) => last_activated_at
178                    .checked_add(self.rotation_interval)
179                    .and_then(|next| next.duration_since(SystemTime::now()).ok())
180                    .unwrap_or(Duration::ZERO),
181                None => Duration::ZERO,
182            };
183
184            if sleep_or_cancel(sleep_dur, &token).await {
185                break;
186            }
187
188            let expected_version = pre_info.map(|(v, _)| v);
189            let new_version = expected_version
190                .map(|v| ((v as usize + 1) % V) as u8)
191                .unwrap_or(0);
192            let key_bytes = (self.generate_key)();
193            let activated_at = SystemTime::now() + self.propagation_delay;
194
195            let encrypted = match self.encryptor.encrypt(&key_bytes).await {
196                Ok(enc) => enc,
197                Err(e) => {
198                    error!(group_id = %self.group_id, error = %e, "KeyRotator: encryption failed");
199                    if sleep_or_cancel(ERROR_RETRY_DELAY, &token).await {
200                        break;
201                    }
202                    continue;
203                }
204            };
205
206            match self
207                .backend
208                .try_insert_key(
209                    &self.group_id,
210                    expected_version,
211                    new_version,
212                    &encrypted,
213                    activated_at,
214                )
215                .await
216            {
217                Ok(true) => {
218                    info!(group_id = %self.group_id, version = new_version, "KeyRotator: new key inserted")
219                }
220                Ok(false) => {
221                    info!(group_id = %self.group_id, "KeyRotator: another instance rotated")
222                }
223                Err(e) => {
224                    error!(group_id = %self.group_id, error = %e, "KeyRotator: try_insert_key failed");
225                    if sleep_or_cancel(ERROR_RETRY_DELAY, &token).await {
226                        break;
227                    }
228                }
229            }
230        }
231        info!(group_id = %self.group_id, "KeyRotator: shutting down");
232    }
233}
234
235async fn sleep_or_cancel(duration: Duration, token: &CancellationToken) -> bool {
236    tokio::select! {
237        biased;
238        _ = token.cancelled() => true,
239        _ = tokio::time::sleep(duration) => false,
240    }
241}
242
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::encryptor::Encrypted;
    use crate::no_op_encryptor::NoOpEncryptor;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    #[derive(Debug, PartialEq)]
    struct MockError;
    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }
    impl std::error::Error for MockError {}

    /// One successful `try_insert_key` call as seen by the mock backend.
    struct TryInsertCall {
        expected_version: Option<u8>,
        new_version: u8,
        ciphertext: Vec<u8>,
        activated_at: SystemTime,
    }

    /// Scripted backend: `latest_key_info` pops from `latest_queue` (an exhausted queue reads
    /// as "no key"), and `try_insert_key` pops from `insert_results` (defaulting to `Ok(true)`)
    /// and records every successful insert in `inserted`.
    struct MockRotationBackend {
        latest_queue: Mutex<VecDeque<Option<(u8, SystemTime)>>>,
        insert_results: Mutex<VecDeque<Result<bool, MockError>>>,
        inserted: Arc<Mutex<Vec<TryInsertCall>>>,
    }

    impl MockRotationBackend {
        fn new(inserted: Arc<Mutex<Vec<TryInsertCall>>>) -> Self {
            Self {
                latest_queue: Mutex::new(VecDeque::new()),
                insert_results: Mutex::new(VecDeque::new()),
                inserted,
            }
        }
        fn push_latest(&self, v: Option<(u8, SystemTime)>) {
            self.latest_queue.lock().unwrap().push_back(v);
        }
    }

    #[async_trait]
    impl SecretRotationBackend for MockRotationBackend {
        type Error = MockError;

        async fn latest_key_info(
            &self,
            _group_id: &str,
        ) -> Result<Option<(u8, SystemTime)>, MockError> {
            Ok(self.latest_queue.lock().unwrap().pop_front().flatten())
        }

        async fn try_insert_key(
            &self,
            _group_id: &str,
            expected_version: Option<u8>,
            new_version: u8,
            encrypted: &Encrypted,
            activated_at: SystemTime,
        ) -> Result<bool, MockError> {
            let result = self
                .insert_results
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or(Ok(true));
            if result == Ok(true) {
                self.inserted.lock().unwrap().push(TryInsertCall {
                    expected_version,
                    new_version,
                    ciphertext: encrypted.ciphertext.clone(),
                    activated_at,
                });
            }
            result
        }
    }

    #[tokio::test]
    async fn rotates_immediately_when_no_key_exists() {
        let inserted = Arc::new(Mutex::new(vec![]));
        let backend = MockRotationBackend::new(Arc::clone(&inserted));
        backend.push_latest(None);
        backend.push_latest(Some((0, SystemTime::now())));

        let rotator: KeyRotator<_, _, 256> = KeyRotator::new(
            "test-rotator",
            backend,
            Duration::from_secs(3600),
            Duration::from_secs(120),
            NoOpEncryptor,
            || [42u8; 32],
        );
        let token = CancellationToken::new();
        let tc = token.clone();
        let handle = tokio::spawn(async move { rotator.run(tc).await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        token.cancel();
        handle.await.unwrap();

        let calls = inserted.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].expected_version, None);
        assert_eq!(calls[0].new_version, 0);
        // NoOpEncryptor passes bytes through as-is
        assert_eq!(calls[0].ciphertext, vec![42u8; 32]);
        assert!(calls[0].activated_at > SystemTime::now() + Duration::from_secs(100));
    }

    #[tokio::test]
    async fn wraps_version_around_ring_buffer() {
        let inserted = Arc::new(Mutex::new(vec![]));
        let backend = MockRotationBackend::new(Arc::clone(&inserted));
        // Last slot of a 4-slot ring, whose rotation deadline passed an hour ago.
        backend.push_latest(Some((3, SystemTime::now() - Duration::from_secs(7200))));
        // The freshly inserted key, observed on the re-read after the insert.
        backend.push_latest(Some((0, SystemTime::now())));

        let rotator: KeyRotator<_, _, 4> = KeyRotator::new(
            "test-rotator",
            backend,
            Duration::from_secs(3600),
            Duration::ZERO,
            NoOpEncryptor,
            || [7u8; 32],
        );
        let token = CancellationToken::new();
        let tc = token.clone();
        let handle = tokio::spawn(async move { rotator.run(tc).await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        token.cancel();
        handle.await.unwrap();

        let calls = inserted.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].expected_version, Some(3));
        assert_eq!(calls[0].new_version, 0);
    }

    #[tokio::test]
    async fn waits_while_current_key_is_fresh() {
        let inserted = Arc::new(Mutex::new(vec![]));
        let backend = MockRotationBackend::new(Arc::clone(&inserted));
        // Activated just now with a one-hour interval: nothing is due during the test.
        backend.push_latest(Some((5, SystemTime::now())));

        let rotator: KeyRotator<_, _, 256> = KeyRotator::new(
            "test-rotator",
            backend,
            Duration::from_secs(3600),
            Duration::from_secs(120),
            NoOpEncryptor,
            || [1u8; 32],
        );
        let token = CancellationToken::new();
        let tc = token.clone();
        let handle = tokio::spawn(async move { rotator.run(tc).await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        token.cancel();
        handle.await.unwrap();

        assert!(inserted.lock().unwrap().is_empty());
    }
}
354        // NoOpEncryptor passes bytes through as-is
355        assert_eq!(calls[0].ciphertext, vec![42u8; 32]);
356        assert!(calls[0].activated_at > SystemTime::now() + Duration::from_secs(100));
357    }
358}