1use crate::backend::{GitBackend, GitRepo};
4use crate::providers;
5use crate::{AuthProvider, CredentialFile, SyncConfig, SyncError, ValidationResult};
6use std::path::{Path, PathBuf};
7use tracing::{debug, info, warn};
8
/// Coordinates credential synchronization between local credential files and a
/// git repository checkout.
pub struct SyncEngine {
    /// Settings read by the engine: repository URL, local checkout path,
    /// shallow-clone flag, provider names, and watch interval.
    pub config: SyncConfig,
    /// Backend that the clone, pull, and push operations are delegated to.
    pub backend: Box<dyn GitBackend>,
    /// Providers whose credential files are pulled, pushed, and reported on.
    pub providers: Vec<Box<dyn AuthProvider>>,
}
16
17impl SyncEngine {
18 pub fn new(config: SyncConfig) -> Result<Self, SyncError> {
20 if config.repo_url.is_empty() {
21 return Err(SyncError::Config("repo_url must not be empty".to_string()));
22 }
23
24 let active_providers = resolve_providers(&config.providers)?;
25 info!(
26 providers = ?active_providers.iter().map(|p| p.name()).collect::<Vec<_>>(),
27 "initialized sync engine"
28 );
29
30 Ok(Self {
31 config,
32 backend: Box::new(GitRepo),
33 providers: active_providers,
34 })
35 }
36
37 pub fn with_backend(
39 config: SyncConfig,
40 backend: Box<dyn GitBackend>,
41 ) -> Result<Self, SyncError> {
42 let active_providers = resolve_providers(&config.providers)?;
43 Ok(Self {
44 config,
45 backend,
46 providers: active_providers,
47 })
48 }
49
50 pub async fn ensure_repo(&self) -> Result<(), SyncError> {
52 if self.backend.is_cloned(&self.config.local_path) {
53 debug!(path = %self.config.local_path.display(), "repo already cloned");
54 return Ok(());
55 }
56 self.backend
57 .clone_repo(
58 &self.config.repo_url,
59 &self.config.local_path,
60 self.config.shallow_clone,
61 )
62 .await
63 }
64
65 pub async fn pull(&self) -> Result<SyncReport, SyncError> {
71 self.ensure_repo().await?;
72 self.backend.pull(&self.config.local_path).await?;
73
74 let mut report = SyncReport::default();
75
76 for provider in &self.providers {
77 for cred in provider.credential_files() {
78 let repo_path = self.config.local_path.join(&cred.relative_path);
79 if repo_path.exists() {
80 copy_recursive(&repo_path, &cred.local_path).await?;
81 report.pulled.push(cred.relative_path.clone());
82 info!(
83 provider = provider.name(),
84 path = %cred.local_path.display(),
85 "pulled credential"
86 );
87 } else {
88 debug!(
89 provider = provider.name(),
90 repo_path = %repo_path.display(),
91 "no credential in repo, skipping"
92 );
93 }
94 }
95 }
96
97 Ok(report)
98 }
99
100 pub async fn push(&self) -> Result<SyncReport, SyncError> {
107 self.ensure_repo().await?;
108
109 if self.backend.is_cloned(&self.config.local_path) {
111 let _ = self.backend.pull(&self.config.local_path).await;
112 }
113
114 let mut report = SyncReport::default();
115
116 for provider in &self.providers {
117 let validation = provider.validate().await;
119 if validation == ValidationResult::Expired {
120 warn!(
121 provider = provider.name(),
122 "skipping push: credentials are expired"
123 );
124 report
125 .skipped
126 .push(format!("{}: credentials expired", provider.name()));
127 continue;
128 }
129
130 for cred in provider.credential_files() {
131 if cred.local_path.exists() {
132 let repo_path = self.config.local_path.join(&cred.relative_path);
133 copy_recursive(&cred.local_path, &repo_path).await?;
134 report.pushed.push(cred.relative_path.clone());
135 info!(
136 provider = provider.name(),
137 path = %cred.local_path.display(),
138 "staged credential for push"
139 );
140 } else {
141 debug!(
142 provider = provider.name(),
143 path = %cred.local_path.display(),
144 "local credential not found, skipping"
145 );
146 }
147 }
148 }
149
150 if !report.pushed.is_empty() {
151 let message = format!(
152 "sync-auth: update credentials ({})",
153 report
154 .pushed
155 .iter()
156 .map(String::as_str)
157 .collect::<Vec<_>>()
158 .join(", ")
159 );
160 self.backend.push(&self.config.local_path, &message).await?;
161 }
162
163 Ok(report)
164 }
165
166 pub async fn sync(&self) -> Result<SyncReport, SyncError> {
168 let mut report = self.pull().await?;
169 let push_report = self.push().await?;
170 report.pushed = push_report.pushed;
171 report.skipped.extend(push_report.skipped);
172 Ok(report)
173 }
174
175 pub async fn watch(&self) -> Result<(), SyncError> {
177 use tokio::time::{interval, Duration};
178 info!(
179 interval_secs = self.config.watch_interval_secs,
180 "starting watch mode"
181 );
182
183 let mut tick = interval(Duration::from_secs(self.config.watch_interval_secs));
184 loop {
185 tick.tick().await;
186 match self.sync().await {
187 Ok(report) => {
188 if !report.pushed.is_empty() || !report.pulled.is_empty() {
189 info!(?report, "sync cycle completed with changes");
190 } else {
191 debug!("sync cycle: no changes");
192 }
193 }
194 Err(e) => {
195 warn!(error = %e, "sync cycle failed, will retry next interval");
196 }
197 }
198 }
199 }
200
201 pub async fn status(&self) -> Vec<ProviderStatus> {
203 let mut statuses = Vec::new();
204 for provider in &self.providers {
205 let validation = provider.validate().await;
206 let files: Vec<_> = provider
207 .credential_files()
208 .into_iter()
209 .map(|c| {
210 let repo_exists = self.config.local_path.join(&c.relative_path).exists();
211 FileStatus {
212 relative_path: c.relative_path,
213 local_exists: c.local_path.exists(),
214 repo_exists,
215 }
216 })
217 .collect();
218 statuses.push(ProviderStatus {
219 name: provider.name().to_string(),
220 display_name: provider.display_name().to_string(),
221 validation,
222 files,
223 });
224 }
225 statuses
226 }
227}
228
/// Summary of what a pull, push, or sync operation did.
///
/// `Clone`, `PartialEq`, and `Eq` are derived so reports can be duplicated and
/// compared directly, for example in assertions.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SyncReport {
    /// Relative paths of credentials copied from the repository checkout to
    /// their local location.
    pub pulled: Vec<String>,
    /// Relative paths of credentials copied from their local location into the
    /// repository checkout.
    pub pushed: Vec<String>,
    /// Human-readable notes about providers that were skipped.
    pub skipped: Vec<String>,
}
236
/// Status snapshot for a single provider, as produced by `SyncEngine::status`.
#[derive(Debug)]
pub struct ProviderStatus {
    /// Value of the provider's `name()`.
    pub name: String,
    /// Value of the provider's `display_name()`.
    pub display_name: String,
    /// Result of the provider's `validate()` call.
    pub validation: ValidationResult,
    /// One entry per credential file the provider declares.
    pub files: Vec<FileStatus>,
}
245
/// Presence of a single credential file locally and in the repository checkout.
///
/// `Clone`, `PartialEq`, and `Eq` are derived so statuses can be duplicated and
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileStatus {
    /// Path of the credential relative to the repository checkout.
    pub relative_path: String,
    /// Whether the credential exists at its local path.
    pub local_exists: bool,
    /// Whether the credential exists inside the repository checkout.
    pub repo_exists: bool,
}
253
254fn resolve_providers(names: &[String]) -> Result<Vec<Box<dyn AuthProvider>>, SyncError> {
256 if names.is_empty() {
257 return Ok(providers::all_providers());
258 }
259 names
260 .iter()
261 .map(|name| {
262 providers::provider_by_name(name)
263 .ok_or_else(|| SyncError::ProviderNotFound(name.clone()))
264 })
265 .collect()
266}
267
268fn copy_recursive<'a>(
270 src: &'a Path,
271 dst: &'a Path,
272) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), SyncError>> + Send + 'a>> {
273 Box::pin(async move {
274 if src.is_dir() {
275 tokio::fs::create_dir_all(dst).await?;
276 let mut entries = tokio::fs::read_dir(src).await?;
277 while let Some(entry) = entries.next_entry().await? {
278 let entry_path = entry.path();
279 let file_name = entry.file_name().to_string_lossy().to_string();
280 let dst_child = dst.join(&file_name);
281 copy_recursive(&entry_path, &dst_child).await?;
282 }
283 } else if src.is_file() {
284 if let Some(parent) = dst.parent() {
285 tokio::fs::create_dir_all(parent).await?;
286 }
287 tokio::fs::copy(src, dst).await?;
288 }
289 Ok(())
290 })
291}
292
293pub fn _credential_path(relative: &str) -> CredentialFile {
295 CredentialFile {
296 relative_path: relative.to_string(),
297 local_path: PathBuf::from(relative),
298 is_dir: false,
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty name list resolves to the full provider set.
    #[test]
    fn test_resolve_all_providers() {
        let providers = resolve_providers(&[]).unwrap();
        assert_eq!(providers.len(), 7);
    }

    /// Named providers are resolved in the order they were requested.
    #[test]
    fn test_resolve_specific_providers() {
        let names = vec!["gh".to_string(), "claude".to_string()];
        let providers = resolve_providers(&names).unwrap();
        assert_eq!(providers.len(), 2);
        assert_eq!(providers[0].name(), "gh");
        assert_eq!(providers[1].name(), "claude");
    }

    /// An unknown name fails with `ProviderNotFound` carrying that name, not
    /// with some unrelated error.
    #[test]
    fn test_resolve_unknown_provider() {
        let names = vec!["nonexistent".to_string()];
        let result = resolve_providers(&names);
        assert!(matches!(
            result,
            Err(SyncError::ProviderNotFound(ref name)) if name == "nonexistent"
        ));
    }

    /// The default configuration is rejected with a configuration error.
    #[test]
    fn test_new_engine_requires_repo_url() {
        let config = SyncConfig::default();
        let result = SyncEngine::new(config);
        assert!(matches!(result, Err(SyncError::Config(_))));
    }

    /// Copying a file creates the missing parent directories of the target.
    #[tokio::test]
    async fn test_copy_recursive_file() {
        let tmp = tempfile::tempdir().unwrap();
        let src = tmp.path().join("src.txt");
        let dst = tmp.path().join("nested").join("dst.txt");
        tokio::fs::write(&src, "hello").await.unwrap();
        copy_recursive(&src, &dst).await.unwrap();
        let content = tokio::fs::read_to_string(&dst).await.unwrap();
        assert_eq!(content, "hello");
    }

    /// Copying a directory reproduces each contained file in the target.
    #[tokio::test]
    async fn test_copy_recursive_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let src_dir = tmp.path().join("src_dir");
        let dst_dir = tmp.path().join("dst_dir");
        tokio::fs::create_dir_all(&src_dir).await.unwrap();
        tokio::fs::write(src_dir.join("a.txt"), "aaa")
            .await
            .unwrap();
        tokio::fs::write(src_dir.join("b.txt"), "bbb")
            .await
            .unwrap();
        copy_recursive(&src_dir, &dst_dir).await.unwrap();
        assert_eq!(
            tokio::fs::read_to_string(dst_dir.join("a.txt"))
                .await
                .unwrap(),
            "aaa"
        );
        assert_eq!(
            tokio::fs::read_to_string(dst_dir.join("b.txt"))
                .await
                .unwrap(),
            "bbb"
        );
    }
}