// sync_auth/engine.rs
1//! Sync engine — orchestrates bidirectional credential sync.
2
3use crate::backend::{GitBackend, GitRepo};
4use crate::providers;
5use crate::{AuthProvider, CredentialFile, SyncConfig, SyncError, ValidationResult};
6use std::path::{Path, PathBuf};
7use tracing::{debug, info, warn};
8
9/// The sync engine coordinates pulling credentials from a Git repo to local
10/// filesystem and pushing local credentials to the repo.
pub struct SyncEngine {
    /// User-supplied sync settings (repo URL, local clone path, provider
    /// names, watch interval, …).
    pub config: SyncConfig,
    /// Git storage backend used for clone/pull/push; boxed trait object so
    /// tests or alternative storage can substitute their own implementation.
    pub backend: Box<dyn GitBackend>,
    /// The providers whose credential files participate in sync.
    pub providers: Vec<Box<dyn AuthProvider>>,
}
16
17impl SyncEngine {
18    /// Create a new `SyncEngine` with the given config and default Git backend.
19    pub fn new(config: SyncConfig) -> Result<Self, SyncError> {
20        if config.repo_url.is_empty() {
21            return Err(SyncError::Config("repo_url must not be empty".to_string()));
22        }
23
24        let active_providers = resolve_providers(&config.providers)?;
25        info!(
26            providers = ?active_providers.iter().map(|p| p.name()).collect::<Vec<_>>(),
27            "initialized sync engine"
28        );
29
30        Ok(Self {
31            config,
32            backend: Box::new(GitRepo),
33            providers: active_providers,
34        })
35    }
36
37    /// Create a `SyncEngine` with a custom backend (for testing or alternative storage).
38    pub fn with_backend(
39        config: SyncConfig,
40        backend: Box<dyn GitBackend>,
41    ) -> Result<Self, SyncError> {
42        let active_providers = resolve_providers(&config.providers)?;
43        Ok(Self {
44            config,
45            backend,
46            providers: active_providers,
47        })
48    }
49
50    /// Ensure the sync repo is cloned locally; clone if not.
51    pub async fn ensure_repo(&self) -> Result<(), SyncError> {
52        if self.backend.is_cloned(&self.config.local_path) {
53            debug!(path = %self.config.local_path.display(), "repo already cloned");
54            return Ok(());
55        }
56        self.backend
57            .clone_repo(
58                &self.config.repo_url,
59                &self.config.local_path,
60                self.config.shallow_clone,
61            )
62            .await
63    }
64
65    /// Pull credentials from the remote repo to local filesystem.
66    ///
67    /// 1. Ensure repo is cloned
68    /// 2. Git pull
69    /// 3. Copy files from repo → local credential paths
70    pub async fn pull(&self) -> Result<SyncReport, SyncError> {
71        self.ensure_repo().await?;
72        self.backend.pull(&self.config.local_path).await?;
73
74        let mut report = SyncReport::default();
75
76        for provider in &self.providers {
77            for cred in provider.credential_files() {
78                let repo_path = self.config.local_path.join(&cred.relative_path);
79                if repo_path.exists() {
80                    copy_recursive(&repo_path, &cred.local_path).await?;
81                    report.pulled.push(cred.relative_path.clone());
82                    info!(
83                        provider = provider.name(),
84                        path = %cred.local_path.display(),
85                        "pulled credential"
86                    );
87                } else {
88                    debug!(
89                        provider = provider.name(),
90                        repo_path = %repo_path.display(),
91                        "no credential in repo, skipping"
92                    );
93                }
94            }
95        }
96
97        Ok(report)
98    }
99
100    /// Push local credentials to the remote repo.
101    ///
102    /// 1. Ensure repo is cloned
103    /// 2. Git pull (to avoid conflicts)
104    /// 3. Copy local credential files → repo
105    /// 4. Git commit + push
106    pub async fn push(&self) -> Result<SyncReport, SyncError> {
107        self.ensure_repo().await?;
108
109        // Pull first to minimize conflicts
110        if self.backend.is_cloned(&self.config.local_path) {
111            let _ = self.backend.pull(&self.config.local_path).await;
112        }
113
114        let mut report = SyncReport::default();
115
116        for provider in &self.providers {
117            // Validate before pushing — don't push expired credentials
118            let validation = provider.validate().await;
119            if validation == ValidationResult::Expired {
120                warn!(
121                    provider = provider.name(),
122                    "skipping push: credentials are expired"
123                );
124                report
125                    .skipped
126                    .push(format!("{}: credentials expired", provider.name()));
127                continue;
128            }
129
130            for cred in provider.credential_files() {
131                if cred.local_path.exists() {
132                    let repo_path = self.config.local_path.join(&cred.relative_path);
133                    copy_recursive(&cred.local_path, &repo_path).await?;
134                    report.pushed.push(cred.relative_path.clone());
135                    info!(
136                        provider = provider.name(),
137                        path = %cred.local_path.display(),
138                        "staged credential for push"
139                    );
140                } else {
141                    debug!(
142                        provider = provider.name(),
143                        path = %cred.local_path.display(),
144                        "local credential not found, skipping"
145                    );
146                }
147            }
148        }
149
150        if !report.pushed.is_empty() {
151            let message = format!(
152                "sync-auth: update credentials ({})",
153                report
154                    .pushed
155                    .iter()
156                    .map(String::as_str)
157                    .collect::<Vec<_>>()
158                    .join(", ")
159            );
160            self.backend.push(&self.config.local_path, &message).await?;
161        }
162
163        Ok(report)
164    }
165
166    /// Bidirectional sync: pull then push.
167    pub async fn sync(&self) -> Result<SyncReport, SyncError> {
168        let mut report = self.pull().await?;
169        let push_report = self.push().await?;
170        report.pushed = push_report.pushed;
171        report.skipped.extend(push_report.skipped);
172        Ok(report)
173    }
174
175    /// Watch for local credential changes and sync periodically.
176    pub async fn watch(&self) -> Result<(), SyncError> {
177        use tokio::time::{interval, Duration};
178        info!(
179            interval_secs = self.config.watch_interval_secs,
180            "starting watch mode"
181        );
182
183        let mut tick = interval(Duration::from_secs(self.config.watch_interval_secs));
184        loop {
185            tick.tick().await;
186            match self.sync().await {
187                Ok(report) => {
188                    if !report.pushed.is_empty() || !report.pulled.is_empty() {
189                        info!(?report, "sync cycle completed with changes");
190                    } else {
191                        debug!("sync cycle: no changes");
192                    }
193                }
194                Err(e) => {
195                    warn!(error = %e, "sync cycle failed, will retry next interval");
196                }
197            }
198        }
199    }
200
201    /// List all providers and their credential status.
202    pub async fn status(&self) -> Vec<ProviderStatus> {
203        let mut statuses = Vec::new();
204        for provider in &self.providers {
205            let validation = provider.validate().await;
206            let files: Vec<_> = provider
207                .credential_files()
208                .into_iter()
209                .map(|c| {
210                    let repo_exists = self.config.local_path.join(&c.relative_path).exists();
211                    FileStatus {
212                        relative_path: c.relative_path,
213                        local_exists: c.local_path.exists(),
214                        repo_exists,
215                    }
216                })
217                .collect();
218            statuses.push(ProviderStatus {
219                name: provider.name().to_string(),
220                display_name: provider.display_name().to_string(),
221                validation,
222                files,
223            });
224        }
225        statuses
226    }
227}
228
/// Report of a sync operation.
#[derive(Debug, Default)]
pub struct SyncReport {
    /// Repo-relative paths copied from the repo to the local filesystem.
    pub pulled: Vec<String>,
    /// Repo-relative paths staged and pushed to the remote repo.
    pub pushed: Vec<String>,
    /// Human-readable reasons for items that were skipped (e.g. expired
    /// credentials).
    pub skipped: Vec<String>,
}
236
/// Status of a provider's credentials.
#[derive(Debug)]
pub struct ProviderStatus {
    /// Machine-readable provider name (e.g. the lookup key).
    pub name: String,
    /// Human-friendly provider name for display.
    pub display_name: String,
    /// Result of the provider's credential validation.
    pub validation: ValidationResult,
    /// Per-file existence status (local filesystem vs. cloned repo).
    pub files: Vec<FileStatus>,
}
245
/// Status of a single credential file.
#[derive(Debug)]
pub struct FileStatus {
    /// Path relative to the sync repo root.
    pub relative_path: String,
    /// Whether the file exists at its local credential path.
    pub local_exists: bool,
    /// Whether the file exists inside the local clone of the sync repo.
    pub repo_exists: bool,
}
253
254/// Resolve provider list: if empty, use all; otherwise look up by name.
255fn resolve_providers(names: &[String]) -> Result<Vec<Box<dyn AuthProvider>>, SyncError> {
256    if names.is_empty() {
257        return Ok(providers::all_providers());
258    }
259    names
260        .iter()
261        .map(|name| {
262            providers::provider_by_name(name)
263                .ok_or_else(|| SyncError::ProviderNotFound(name.clone()))
264        })
265        .collect()
266}
267
268/// Recursively copy a file or directory, creating parent dirs as needed.
269fn copy_recursive<'a>(
270    src: &'a Path,
271    dst: &'a Path,
272) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), SyncError>> + Send + 'a>> {
273    Box::pin(async move {
274        if src.is_dir() {
275            tokio::fs::create_dir_all(dst).await?;
276            let mut entries = tokio::fs::read_dir(src).await?;
277            while let Some(entry) = entries.next_entry().await? {
278                let entry_path = entry.path();
279                let file_name = entry.file_name().to_string_lossy().to_string();
280                let dst_child = dst.join(&file_name);
281                copy_recursive(&entry_path, &dst_child).await?;
282            }
283        } else if src.is_file() {
284            if let Some(parent) = dst.parent() {
285                tokio::fs::create_dir_all(parent).await?;
286            }
287            tokio::fs::copy(src, dst).await?;
288        }
289        Ok(())
290    })
291}
292
/// Helper to get a PathBuf for credential files (used by providers).
///
/// Builds a non-directory `CredentialFile` whose repo-relative path and local
/// path are both taken verbatim from `relative`.
pub fn _credential_path(relative: &str) -> CredentialFile {
    CredentialFile {
        relative_path: relative.to_string(),
        // NOTE(review): local_path is the relative string as-is, i.e. relative
        // to the process CWD — presumably callers resolve it against the home
        // directory; confirm against the provider implementations.
        local_path: PathBuf::from(relative),
        is_dir: false,
    }
}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty name list resolves to the full provider registry.
    /// NOTE(review): the hard-coded 7 must track `providers::all_providers()`;
    /// this test breaks whenever a provider is added or removed.
    #[test]
    fn test_resolve_all_providers() {
        let providers = resolve_providers(&[]).unwrap();
        assert_eq!(providers.len(), 7);
    }

    /// Explicit names resolve in order, preserving the requested sequence.
    #[test]
    fn test_resolve_specific_providers() {
        let names = vec!["gh".to_string(), "claude".to_string()];
        let providers = resolve_providers(&names).unwrap();
        assert_eq!(providers.len(), 2);
        assert_eq!(providers[0].name(), "gh");
        assert_eq!(providers[1].name(), "claude");
    }

    /// An unknown provider name is rejected rather than silently dropped.
    #[test]
    fn test_resolve_unknown_provider() {
        let names = vec!["nonexistent".to_string()];
        let result = resolve_providers(&names);
        assert!(result.is_err());
    }

    /// `SyncEngine::new` rejects a default config (empty repo_url).
    #[test]
    fn test_new_engine_requires_repo_url() {
        let config = SyncConfig::default();
        let result = SyncEngine::new(config);
        assert!(result.is_err());
    }

    /// Copying a single file creates missing parent directories for the
    /// destination and preserves file contents.
    #[tokio::test]
    async fn test_copy_recursive_file() {
        let tmp = tempfile::tempdir().unwrap();
        let src = tmp.path().join("src.txt");
        let dst = tmp.path().join("nested").join("dst.txt");
        tokio::fs::write(&src, "hello").await.unwrap();
        copy_recursive(&src, &dst).await.unwrap();
        let content = tokio::fs::read_to_string(&dst).await.unwrap();
        assert_eq!(content, "hello");
    }

    /// Copying a directory recreates it at the destination with all children.
    #[tokio::test]
    async fn test_copy_recursive_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let src_dir = tmp.path().join("src_dir");
        let dst_dir = tmp.path().join("dst_dir");
        tokio::fs::create_dir_all(&src_dir).await.unwrap();
        tokio::fs::write(src_dir.join("a.txt"), "aaa")
            .await
            .unwrap();
        tokio::fs::write(src_dir.join("b.txt"), "bbb")
            .await
            .unwrap();
        copy_recursive(&src_dir, &dst_dir).await.unwrap();
        assert_eq!(
            tokio::fs::read_to_string(dst_dir.join("a.txt"))
                .await
                .unwrap(),
            "aaa"
        );
        assert_eq!(
            tokio::fs::read_to_string(dst_dir.join("b.txt"))
                .await
                .unwrap(),
            "bbb"
        );
    }
}