// tryaudex_core/rotation.rs

1use std::path::PathBuf;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use crate::credentials::TempCredentials;
7use crate::error::Result;
8
/// Tuning knobs for automatic credential rotation.
#[derive(Debug, Clone)]
pub struct RotationConfig {
    /// Lead time, in seconds, before expiry at which credentials are refreshed.
    pub rotate_before_secs: u64,
    /// Shortest session TTL, in seconds, for which rotation is enabled.
    /// Sessions below this threshold use static credentials.
    pub min_ttl_for_rotation_secs: u64,
}

impl Default for RotationConfig {
    /// Rotate 5 minutes before expiry, and only for sessions of 15 minutes or longer.
    fn default() -> Self {
        const FIVE_MINUTES_SECS: u64 = 5 * 60;
        const FIFTEEN_MINUTES_SECS: u64 = 15 * 60;

        RotationConfig {
            min_ttl_for_rotation_secs: FIFTEEN_MINUTES_SECS,
            rotate_before_secs: FIVE_MINUTES_SECS,
        }
    }
}
27
28/// A file that holds provider credentials, updated atomically.
29/// Automatically cleaned up when dropped.
30pub struct CredentialFile {
31    path: PathBuf,
32}
33
34impl CredentialFile {
35    /// Create a new credential file in a temp directory.
36    pub fn new() -> Result<Self> {
37        let dir = std::env::temp_dir().join("audex");
38        std::fs::create_dir_all(&dir)?;
39        let path = dir.join(format!("creds-{}.ini", uuid::Uuid::new_v4()));
40        Ok(Self { path })
41    }
42
43    /// Path to the credential file.
44    pub fn path(&self) -> &std::path::Path {
45        &self.path
46    }
47
48    /// Write credentials in AWS shared credentials file format.
49    /// Uses atomic write (write to tmp, then rename) to avoid partial reads.
50    pub fn write_aws(&self, creds: &TempCredentials) -> Result<()> {
51        let content = format!(
52            "[default]\naws_access_key_id = {}\naws_secret_access_key = {}\naws_session_token = {}\n",
53            creds.access_key_id, creds.secret_access_key, creds.session_token
54        );
55        atomic_write(&self.path, content.as_bytes())
56    }
57
58    /// Write a GCP access token to the credential file.
59    pub fn write_gcp(&self, token: &str) -> Result<()> {
60        atomic_write(&self.path, token.as_bytes())
61    }
62
63    /// Write an Azure access token to the credential file.
64    pub fn write_azure(&self, token: &str) -> Result<()> {
65        atomic_write(&self.path, token.as_bytes())
66    }
67
68    /// Remove the credential file from disk.
69    pub fn cleanup(&self) {
70        let _ = std::fs::remove_file(&self.path);
71        let _ = std::fs::remove_file(self.path.with_extension("tmp"));
72    }
73}
74
75impl Drop for CredentialFile {
76    fn drop(&mut self) {
77        self.cleanup();
78    }
79}
80
/// Atomically replace the contents of `path` with `data`.
///
/// The bytes are first written to a `.tmp` sibling in the same directory (so
/// the final `rename` stays on one filesystem), flushed to disk, and then
/// renamed over `path`. Readers see either the old contents or the new ones,
/// never a partial write.
///
/// On Unix the temporary file is created with mode `0o600`, because the
/// payload is credential material and the parent directory sits in the shared
/// system temp dir. Any stale `.tmp` left by an earlier run is removed first.
/// The `.tmp` is also removed if any step fails.
///
/// # Errors
/// Returns an error if the temporary file cannot be created, written, synced,
/// or renamed over `path`.
fn atomic_write(path: &std::path::Path, data: &[u8]) -> Result<()> {
    let tmp = path.with_extension("tmp");
    let outcome = write_then_rename(&tmp, path, data);
    if outcome.is_err() {
        // Best effort: don't leave partially written secrets lying around.
        let _ = std::fs::remove_file(&tmp);
    }
    outcome?;
    Ok(())
}

/// Create `tmp` as a fresh owner-only file, write and fsync `data`, then rename it over `path`.
fn write_then_rename(
    tmp: &std::path::Path,
    path: &std::path::Path,
    data: &[u8],
) -> std::io::Result<()> {
    use std::io::Write;

    // A leftover tmp could carry looser permissions (or be a planted symlink).
    // Remove it so `create_new` (O_EXCL) below creates a brand-new file we own.
    let _ = std::fs::remove_file(tmp);

    let mut options = std::fs::OpenOptions::new();
    options.write(true).create_new(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }

    let mut file = options.open(tmp)?;
    file.write_all(data)?;
    // Make sure the bytes are durable before the rename publishes them.
    file.sync_all()?;
    drop(file);

    std::fs::rename(tmp, path)
}
88
89/// Handle to a running credential rotation background thread.
90/// Dropping the handle signals the thread to stop.
91pub struct RotationHandle {
92    stop: Arc<AtomicBool>,
93    thread: Option<std::thread::JoinHandle<()>>,
94}
95
96impl RotationHandle {
97    /// Signal the rotation thread to stop and wait for it to finish.
98    pub fn stop(mut self) {
99        self.stop.store(true, Ordering::Relaxed);
100        if let Some(handle) = self.thread.take() {
101            let _ = handle.join();
102        }
103    }
104}
105
106impl Drop for RotationHandle {
107    fn drop(&mut self) {
108        self.stop.store(true, Ordering::Relaxed);
109        if let Some(handle) = self.thread.take() {
110            let _ = handle.join();
111        }
112    }
113}
114
115/// Start a background credential rotation thread.
116///
117/// The `refresh_fn` closure is called whenever new credentials are needed.
118/// It runs inside a dedicated tokio runtime on the background thread.
119/// The rotation loop checks every 10 seconds whether it's time to rotate.
120pub fn start_rotation<F>(
121    cred_file: Arc<CredentialFile>,
122    initial_expires_at: chrono::DateTime<chrono::Utc>,
123    rotate_before: Duration,
124    refresh_fn: F,
125) -> RotationHandle
126where
127    F: Fn() -> std::result::Result<TempCredentials, String> + Send + 'static,
128{
129    let stop = Arc::new(AtomicBool::new(false));
130    let stop_clone = stop.clone();
131
132    let thread = std::thread::Builder::new()
133        .name("audex-rotation".into())
134        .spawn(move || {
135            let rotate_chrono =
136                chrono::Duration::from_std(rotate_before).unwrap_or(chrono::Duration::seconds(300));
137            let mut next_rotation = initial_expires_at - rotate_chrono;
138
139            loop {
140                if stop_clone.load(Ordering::Relaxed) {
141                    break;
142                }
143
144                let now = chrono::Utc::now();
145                if now >= next_rotation {
146                    tracing::info!("Credential rotation triggered");
147                    match refresh_fn() {
148                        Ok(new_creds) => {
149                            match cred_file.write_aws(&new_creds) {
150                                Ok(()) => {
151                                    tracing::info!(
152                                        "Credentials rotated successfully, new expiry: {}",
153                                        new_creds.expires_at
154                                    );
155                                    next_rotation = new_creds.expires_at - rotate_chrono;
156                                }
157                                Err(e) => {
158                                    tracing::error!("Failed to write rotated credentials: {}", e);
159                                    // Retry in 30 seconds
160                                    next_rotation = now + chrono::Duration::seconds(30);
161                                }
162                            }
163                        }
164                        Err(e) => {
165                            tracing::error!("Failed to refresh credentials: {}", e);
166                            // Retry in 30 seconds
167                            next_rotation = now + chrono::Duration::seconds(30);
168                        }
169                    }
170                }
171
172                // Check every 10 seconds
173                std::thread::sleep(Duration::from_secs(10));
174            }
175
176            tracing::debug!("Credential rotation thread stopped");
177        })
178        .expect("failed to spawn rotation thread");
179
180    RotationHandle {
181        stop,
182        thread: Some(thread),
183    }
184}
185
186/// Check if rotation should be enabled for a given TTL.
187pub fn should_rotate(ttl: Duration, config: &RotationConfig) -> bool {
188    ttl.as_secs() >= config.min_ttl_for_rotation_secs
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_should_rotate_long_ttl() {
        let config = RotationConfig::default();
        assert!(should_rotate(Duration::from_secs(3600), &config)); // 1 hour
        assert!(should_rotate(Duration::from_secs(900), &config)); // exactly 15 min
    }

    #[test]
    fn test_should_not_rotate_short_ttl() {
        let config = RotationConfig::default();
        assert!(!should_rotate(Duration::from_secs(600), &config)); // 10 min
        assert!(!should_rotate(Duration::from_secs(300), &config)); // 5 min
        assert!(!should_rotate(Duration::from_secs(60), &config)); // 1 min
    }

    #[test]
    fn test_credential_file_write_aws() {
        let cred_file = CredentialFile::new().unwrap();
        let creds = TempCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: "FwoGZXtoken123".to_string(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
        };
        cred_file.write_aws(&creds).unwrap();

        let content = std::fs::read_to_string(cred_file.path()).unwrap();
        assert!(content.contains("[default]"));
        assert!(content.contains("AKIAIOSFODNN7EXAMPLE"));
        assert!(content.contains("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"));
        assert!(content.contains("FwoGZXtoken123"));
    }

    #[test]
    fn test_credential_file_write_gcp() {
        let cred_file = CredentialFile::new().unwrap();
        cred_file.write_gcp("ya29.test-token-123").unwrap();

        let content = std::fs::read_to_string(cred_file.path()).unwrap();
        assert_eq!(content, "ya29.test-token-123");
    }

    #[test]
    fn test_credential_file_cleanup_on_drop() {
        let path;
        {
            let cred_file = CredentialFile::new().unwrap();
            cred_file.write_gcp("test").unwrap();
            path = cred_file.path().to_path_buf();
            assert!(path.exists());
        }
        // After drop, file should be cleaned up
        assert!(!path.exists());
    }

    #[test]
    fn test_rotation_handle_stop() {
        use std::sync::atomic::AtomicU32;

        let cred_file = Arc::new(CredentialFile::new().unwrap());
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        // Set expiry in the past so rotation triggers immediately
        let expires_at = chrono::Utc::now() - chrono::Duration::seconds(10);

        let initial_creds = TempCredentials {
            access_key_id: "AKIATEST".to_string(),
            secret_access_key: "secret".to_string(),
            session_token: "token".to_string(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
        };

        // Write initial credentials so the file exists
        cred_file.write_aws(&initial_creds).unwrap();

        let handle = start_rotation(
            cred_file.clone(),
            expires_at,
            Duration::from_secs(60),
            move || {
                call_count_clone.fetch_add(1, Ordering::Relaxed);
                Ok(TempCredentials {
                    access_key_id: "AKIAROTATED".to_string(),
                    secret_access_key: "new-secret".to_string(),
                    session_token: "new-token".to_string(),
                    expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
                })
            },
        );

        // Poll (with a deadline) for the rotated credentials rather than
        // sleeping a fixed amount, which is flaky on slow or loaded machines.
        let deadline = Instant::now() + Duration::from_secs(5);
        while Instant::now() < deadline {
            let current = std::fs::read_to_string(cred_file.path()).unwrap_or_default();
            if current.contains("AKIAROTATED") {
                break;
            }
            std::thread::sleep(Duration::from_millis(20));
        }
        handle.stop();

        // refresh_fn should have been called at least once
        assert!(call_count.load(Ordering::Relaxed) >= 1);

        // Credential file should have rotated credentials
        let content = std::fs::read_to_string(cred_file.path()).unwrap();
        assert!(content.contains("AKIAROTATED"));
    }

    #[test]
    fn test_rotation_config_defaults() {
        let config = RotationConfig::default();
        assert_eq!(config.rotate_before_secs, 300);
        assert_eq!(config.min_ttl_for_rotation_secs, 900);
    }

    #[test]
    fn test_atomic_write_is_atomic() {
        // Per-process directory so concurrent test runs cannot clobber each other.
        let dir = std::env::temp_dir().join(format!("audex-test-atomic-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("test-atomic.txt");

        atomic_write(&path, b"hello world").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello world");

        // Overwrite
        atomic_write(&path, b"updated").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "updated");

        // tmp file should not exist
        assert!(!path.with_extension("tmp").exists());

        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }
}