1use std::collections::BTreeMap;
42use std::fs;
43use std::path::{Path, PathBuf};
44
45use crate::error::{Result, ToriiError};
46
/// Environment variable that, when set to a non-empty value, supplies the
/// cloud API key and takes precedence over the key stored on disk.
const CLOUD_ENV_VAR: &str = "TORII_API_KEY";
/// File name of the credentials file, used both under the global `torii`
/// config directory and under a repository's `.torii/` directory.
const FILE_NAME: &str = "auth.toml";
49
/// Credentials for the Torii cloud API: the secret key and the endpoint it is
/// used against.
///
/// `Debug` is implemented by hand so that formatting an `ApiKey` with `{:?}`
/// (log lines, panic messages, `dbg!`) never prints the secret key; it only
/// shows whether a key is present.
#[derive(Clone, Default)]
pub struct ApiKey {
    /// Secret API key.
    pub key: String,
    /// Base URL of the cloud API.
    pub endpoint: String,
}

impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is set, never its contents.
        let key = if self.key.is_empty() { "" } else { "<redacted>" };
        f.debug_struct("ApiKey")
            .field("key", &key)
            .field("endpoint", &self.endpoint)
            .finish()
    }
}
57
58#[derive(Debug, Clone, Default)]
68pub struct AuthStore {
69 pub cloud: Option<ApiKey>,
70 pub tokens: BTreeMap<String, String>,
71 pub expirations: BTreeMap<String, String>,
72 pub refresh_tokens: BTreeMap<String, String>,
79}
80
/// Canonical (lowercase) names of every provider accepted by
/// `normalise_provider`. The order here is also the order listed in its
/// "unknown provider" error message.
pub const PROVIDERS: &[&str] = &[
    "github", "gitlab", "gitea", "forgejo", "codeberg", "bitbucket", "sourcehut",
    "azure", "cargo",
];
95
/// Returns the Torii cloud API endpoint.
///
/// `TORII_API_ENDPOINT` overrides the built-in `https://api.gitorii.com`
/// when it is set to a non-empty value. An unset, non-UTF-8 or empty value
/// falls back to the default. This matches how the `TORII_API_KEY` and token
/// environment variables treat empty values, and avoids producing an empty
/// endpoint.
pub fn default_endpoint() -> String {
    std::env::var("TORII_API_ENDPOINT")
        .ok()
        .filter(|v| !v.is_empty())
        .unwrap_or_else(|| "https://api.gitorii.com".to_string())
}
102
103fn global_path() -> Option<PathBuf> {
106 dirs::config_dir().map(|d| d.join("torii").join(FILE_NAME))
107}
108
109fn local_path<P: AsRef<Path>>(repo_path: P) -> PathBuf {
110 repo_path.as_ref().join(".torii").join(FILE_NAME)
111}
112
113pub fn load() -> Option<ApiKey> {
117 if let Ok(env_key) = std::env::var(CLOUD_ENV_VAR) {
118 if !env_key.is_empty() {
119 return Some(ApiKey {
120 key: env_key,
121 endpoint: default_endpoint(),
122 });
123 }
124 }
125 load_global().cloud
126}
127
128pub fn load_global() -> AuthStore {
131 let Some(path) = global_path() else {
132 return AuthStore::default();
133 };
134 if !path.exists() {
135 return migrate_from_config_toml().unwrap_or_default();
138 }
139 let text = match fs::read_to_string(&path) {
140 Ok(t) => t,
141 Err(_) => return AuthStore::default(),
142 };
143 parse(&text)
144}
145
146pub fn load_local_raw<P: AsRef<Path>>(repo_path: P) -> AuthStore {
150 let path = local_path(repo_path);
151 if !path.exists() {
152 return AuthStore::default();
153 }
154 let text = match fs::read_to_string(&path) {
155 Ok(t) => t,
156 Err(_) => return AuthStore::default(),
157 };
158 parse(&text)
159}
160
161fn save_to(path: &Path, store: &AuthStore) -> Result<()> {
165 if let Some(parent) = path.parent() {
166 fs::create_dir_all(parent)
167 .map_err(|e| ToriiError::Fs(format!("create dir: {}", e)))?;
168 }
169 let mut out = String::new();
170 out.push_str("# torii credentials — managed by 'torii auth …'. Do not share.\n\n");
171 if let Some(cloud) = &store.cloud {
172 out.push_str("[cloud]\n");
173 out.push_str(&format!("key = \"{}\"\n", cloud.key));
174 out.push_str(&format!("endpoint = \"{}\"\n\n", cloud.endpoint));
175 }
176 if !store.tokens.is_empty() {
177 out.push_str("[tokens]\n");
178 for (k, v) in &store.tokens {
179 out.push_str(&format!("{} = \"{}\"\n", k, v));
180 }
181 out.push('\n');
182 }
183 if !store.expirations.is_empty() {
184 out.push_str("[token_expires]\n");
185 for (k, v) in &store.expirations {
186 out.push_str(&format!("{} = \"{}\"\n", k, v));
187 }
188 out.push('\n');
189 }
190 if !store.refresh_tokens.is_empty() {
191 out.push_str("[token_refresh]\n");
192 for (k, v) in &store.refresh_tokens {
193 out.push_str(&format!("{} = \"{}\"\n", k, v));
194 }
195 }
196 fs::write(path, out)
197 .map_err(|e| ToriiError::Fs(format!("write {}: {}", path.display(), e)))?;
198 restrict_permissions(path);
199 Ok(())
200}
201
202pub fn save_global(store: &AuthStore) -> Result<()> {
203 let path = global_path()
204 .ok_or_else(|| ToriiError::InvalidConfig("could not resolve config dir".to_string()))?;
205 save_to(&path, store)
206}
207
208pub fn save_local<P: AsRef<Path>>(repo_path: P, store: &AuthStore) -> Result<()> {
209 let path = local_path(repo_path);
210 save_to(&path, store)
211}
212
213pub fn save_cloud(key: &str, endpoint: &str) -> Result<()> {
216 let mut store = load_global();
217 store.cloud = Some(ApiKey {
218 key: key.to_string(),
219 endpoint: endpoint.to_string(),
220 });
221 save_global(&store)
222}
223
224pub fn delete() -> Result<()> {
226 let mut store = load_global();
227 store.cloud = None;
228 if store.tokens.is_empty() {
229 if let Some(path) = global_path() {
231 if path.exists() {
232 fs::remove_file(&path).map_err(|e| {
233 ToriiError::Fs(format!("remove {}: {}", path.display(), e))
234 })?;
235 }
236 }
237 return Ok(());
238 }
239 save_global(&store)
240}
241
242pub fn normalise_provider(name: &str) -> Result<String> {
247 let lc = name.to_lowercase();
248 if PROVIDERS.iter().any(|p| **p == lc) {
249 Ok(lc)
250 } else {
251 Err(ToriiError::Usage(format!(
252 "unknown provider '{}'. Known: {}",
253 name,
254 PROVIDERS.join(", ")
255 )))
256 }
257}
258
259pub fn set_token(provider: &str, token: &str, local: Option<&Path>) -> Result<()> {
260 set_token_with_expiry(provider, token, None, local)
261}
262
263pub fn set_token_with_expiry(
269 provider: &str,
270 token: &str,
271 expires_at: Option<&str>,
272 local: Option<&Path>,
273) -> Result<()> {
274 let provider = normalise_provider(provider)?;
275 let result = if let Some(repo) = local {
276 let mut store = load_local_raw(repo);
277 store.tokens.insert(provider.clone(), token.to_string());
278 apply_expiry(&mut store.expirations, &provider, expires_at);
279 save_local(repo, &store)
280 } else {
281 let mut store = load_global();
282 store.tokens.insert(provider.clone(), token.to_string());
283 apply_expiry(&mut store.expirations, &provider, expires_at);
284 save_global(&store)
285 };
286 invalidate_token_cache();
287 result
288}
289
290pub fn set_token_with_refresh(
296 provider: &str,
297 access_token: &str,
298 refresh_token: Option<&str>,
299 expires_in_seconds: Option<u64>,
300) -> Result<()> {
301 let provider = normalise_provider(provider)?;
302 let expires_at = expires_in_seconds.map(|s| {
303 let when = chrono::Utc::now() + chrono::Duration::seconds(s as i64);
304 when.to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
305 });
306 let mut store = load_global();
307 store.tokens.insert(provider.clone(), access_token.to_string());
308 apply_expiry(&mut store.expirations, &provider, expires_at.as_deref());
309 if let Some(r) = refresh_token {
310 store.refresh_tokens.insert(provider.clone(), r.to_string());
311 }
312 save_global(&store)?;
313 invalidate_token_cache();
314 Ok(())
315}
316
317pub fn refresh_if_needed(provider: &str) -> Result<bool> {
324 let provider_lc = provider.to_lowercase();
325 let store = load_global();
326 let Some(refresh) = store.refresh_tokens.get(&provider_lc).cloned() else {
327 return Ok(false);
328 };
329 let due = store.expirations.get(&provider_lc)
333 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
334 .map(|when| {
335 let now = chrono::Utc::now();
336 when.with_timezone(&chrono::Utc) - now < chrono::Duration::minutes(5)
337 })
338 .unwrap_or(false);
339 if !due { return Ok(false); }
340
341 let (new_access, new_refresh, expires_in) =
342 crate::oauth::refresh_access_token(&provider_lc, &refresh)?;
343 set_token_with_refresh(
344 &provider_lc,
345 &new_access,
346 new_refresh.as_deref().or(Some(&refresh)),
347 expires_in,
348 )?;
349 Ok(true)
350}
351
/// Records `expires_at` for `provider` in `map`, or removes the provider's
/// entry when `expires_at` is `None` or an empty string.
fn apply_expiry(map: &mut BTreeMap<String, String>, provider: &str, expires_at: Option<&str>) {
    if let Some(when) = expires_at.filter(|s| !s.is_empty()) {
        map.insert(provider.to_owned(), when.to_owned());
    } else {
        map.remove(provider);
    }
}
358
359pub fn token_expires_at(provider: &str) -> Option<String> {
364 let store = load_global();
365 store.expirations.get(&provider.to_lowercase()).cloned()
366}
367
368pub fn remove_token(provider: &str, local: Option<&Path>) -> Result<bool> {
369 let provider = normalise_provider(provider)?;
370 let removed = if let Some(repo) = local {
371 let mut store = load_local_raw(repo);
372 let r = store.tokens.remove(&provider).is_some();
373 store.expirations.remove(&provider);
374 save_local(repo, &store)?;
375 r
376 } else {
377 let mut store = load_global();
378 let r = store.tokens.remove(&provider).is_some();
379 store.expirations.remove(&provider);
380 save_global(&store)?;
381 r
382 };
383 invalidate_token_cache();
384 Ok(removed)
385}
386
/// Where a resolved token came from, listed in the order
/// `resolve_token_uncached` checks the sources.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TokenSource {
    /// A provider-specific environment variable; the payload is its name
    /// (e.g. `GITHUB_TOKEN`).
    EnvVar(&'static str),
    /// The provider-agnostic `TORII_HTTPS_TOKEN` environment variable.
    EnvGeneric,
    /// The repository-local `.torii/auth.toml`.
    Local,
    /// The global credentials file.
    Global,
    /// No non-empty token was found in any source.
    Missing,
}
398
399#[derive(Debug, Clone)]
400pub struct ResolvedToken {
401 #[allow(dead_code)]
405 pub provider: String,
406 pub value: Option<String>,
407 pub source: TokenSource,
408}
409
410fn token_cache() -> &'static std::sync::Mutex<std::collections::HashMap<String, ResolvedToken>> {
420 static CACHE: std::sync::OnceLock<std::sync::Mutex<std::collections::HashMap<String, ResolvedToken>>> = std::sync::OnceLock::new();
421 CACHE.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()))
422}
423
424fn invalidate_token_cache() {
425 if let Ok(mut g) = token_cache().lock() {
426 g.clear();
427 }
428}
429
430pub fn drop_token_cache() {
437 invalidate_token_cache();
438}
439
440pub fn resolve_token<P: AsRef<Path>>(provider: &str, repo_path: P) -> ResolvedToken {
448 let _ = refresh_if_needed(provider);
457
458 let key = format!("{}|{}", provider.to_lowercase(), repo_path.as_ref().display());
459 if let Ok(g) = token_cache().lock() {
460 if let Some(hit) = g.get(&key) {
461 return hit.clone();
462 }
463 }
464 let result = resolve_token_uncached(provider, repo_path);
465 if let Ok(mut g) = token_cache().lock() {
466 g.insert(key, result.clone());
467 }
468 result
469}
470
471fn resolve_token_uncached<P: AsRef<Path>>(provider: &str, repo_path: P) -> ResolvedToken {
472 let provider_lc = provider.to_lowercase();
473
474 for env_name in env_vars_for(&provider_lc) {
476 if let Ok(v) = std::env::var(env_name) {
477 if !v.is_empty() {
478 return ResolvedToken {
479 provider: provider_lc,
480 value: Some(v),
481 source: TokenSource::EnvVar(env_name),
482 };
483 }
484 }
485 }
486
487 if let Ok(v) = std::env::var("TORII_HTTPS_TOKEN") {
489 if !v.is_empty() {
490 return ResolvedToken {
491 provider: provider_lc,
492 value: Some(v),
493 source: TokenSource::EnvGeneric,
494 };
495 }
496 }
497
498 let local = load_local_raw(repo_path);
500 if let Some(v) = local.tokens.get(&provider_lc) {
501 if !v.is_empty() {
502 return ResolvedToken {
503 provider: provider_lc,
504 value: Some(v.clone()),
505 source: TokenSource::Local,
506 };
507 }
508 }
509
510 let global = load_global();
512 if let Some(v) = global.tokens.get(&provider_lc) {
513 if !v.is_empty() {
514 return ResolvedToken {
515 provider: provider_lc,
516 value: Some(v.clone()),
517 source: TokenSource::Global,
518 };
519 }
520 }
521
522 ResolvedToken {
523 provider: provider_lc,
524 value: None,
525 source: TokenSource::Missing,
526 }
527}
528
/// Environment variables consulted for `provider` (expected lowercase), in
/// order of precedence. Unknown providers get an empty list.
fn env_vars_for(provider: &str) -> &'static [&'static str] {
    const TABLE: &[(&str, &[&str])] = &[
        ("github", &["GITHUB_TOKEN", "GH_TOKEN"]),
        ("gitlab", &["GITLAB_TOKEN", "GL_TOKEN"]),
        ("gitea", &["GITEA_TOKEN"]),
        ("forgejo", &["FORGEJO_TOKEN"]),
        ("codeberg", &["CODEBERG_TOKEN"]),
        ("bitbucket", &["BITBUCKET_TOKEN"]),
        ("azure", &["AZURE_DEVOPS_TOKEN", "AZURE_DEVOPS_EXT_PAT", "AZDO_TOKEN"]),
        ("sourcehut", &["SOURCEHUT_TOKEN", "SRHT_TOKEN"]),
        ("cargo", &["CARGO_REGISTRY_TOKEN"]),
    ];
    TABLE
        .iter()
        .find(|(name, _)| *name == provider)
        .map(|&(_, vars)| vars)
        .unwrap_or(&[])
}
546
547fn parse(text: &str) -> AuthStore {
552 enum Section {
553 TopLevel,
554 Cloud,
555 Tokens,
556 TokenExpires,
557 TokenRefresh,
558 }
559 let mut section = Section::TopLevel;
560 let mut cloud_key = String::new();
561 let mut cloud_endpoint = default_endpoint();
562 let mut have_cloud = false;
563 let mut tokens = BTreeMap::new();
564 let mut expirations = BTreeMap::new();
565 let mut refresh_tokens = BTreeMap::new();
566
567 for raw in text.lines() {
568 let line = raw.trim();
569 if line.is_empty() || line.starts_with('#') {
570 continue;
571 }
572 if line.starts_with('[') && line.ends_with(']') {
573 let name = &line[1..line.len() - 1];
574 section = match name.trim() {
575 "cloud" => Section::Cloud,
576 "tokens" => Section::Tokens,
577 "token_expires" => Section::TokenExpires,
578 "token_refresh" => Section::TokenRefresh,
579 _ => Section::TopLevel, };
581 continue;
582 }
583 let Some((k, v)) = line.split_once('=') else {
584 continue;
585 };
586 let k = k.trim();
587 let v = v.trim().trim_matches('"').to_string();
588 match section {
589 Section::Cloud | Section::TopLevel => match k {
590 "key" => {
591 cloud_key = v;
592 have_cloud = true;
593 }
594 "endpoint" => {
595 cloud_endpoint = v;
596 }
597 _ => {}
598 },
599 Section::Tokens => {
600 if !v.is_empty() {
601 tokens.insert(k.to_string(), v);
602 }
603 }
604 Section::TokenExpires => {
605 if !v.is_empty() {
606 expirations.insert(k.to_string(), v);
607 }
608 }
609 Section::TokenRefresh => {
610 if !v.is_empty() {
611 refresh_tokens.insert(k.to_string(), v);
612 }
613 }
614 }
615 }
616
617 AuthStore {
618 cloud: if have_cloud && !cloud_key.is_empty() {
619 Some(ApiKey {
620 key: cloud_key,
621 endpoint: cloud_endpoint,
622 })
623 } else {
624 None
625 },
626 tokens,
627 expirations,
628 refresh_tokens,
629 }
630}
631
632fn migrate_from_config_toml() -> Option<AuthStore> {
639 let config_path = dirs::config_dir()?.join("torii").join("config.toml");
640 if !config_path.exists() {
641 return None;
642 }
643 let text = fs::read_to_string(&config_path).ok()?;
644
645 let mut tokens = BTreeMap::new();
646 let mut in_auth = false;
647 for raw in text.lines() {
648 let line = raw.trim();
649 if line.is_empty() || line.starts_with('#') {
650 continue;
651 }
652 if line.starts_with('[') && line.ends_with(']') {
653 in_auth = line.trim_start_matches('[').trim_end_matches(']').trim() == "auth";
654 continue;
655 }
656 if !in_auth {
657 continue;
658 }
659 let Some((k, v)) = line.split_once('=') else {
660 continue;
661 };
662 let key = k.trim();
663 let value = v.trim().trim_matches('"').to_string();
664 if value.is_empty() {
665 continue;
666 }
667 if let Some(provider) = key.strip_suffix("_token") {
670 tokens.insert(provider.to_string(), value);
671 }
672 }
673 if tokens.is_empty() {
674 return None;
675 }
676 let store = AuthStore {
677 cloud: None,
678 tokens,
679 expirations: BTreeMap::new(),
680 refresh_tokens: BTreeMap::new(),
681 };
682 let _ = save_global(&store);
683 Some(store)
684}
685
/// Best effort: limits `path` to owner read/write (mode 0600) so other local
/// users cannot read stored credentials. Failures are deliberately ignored.
#[cfg(unix)]
fn restrict_permissions(path: &std::path::Path) {
    use std::os::unix::fs::PermissionsExt;
    let owner_only = fs::Permissions::from_mode(0o600);
    let _ = fs::set_permissions(path, owner_only);
}

/// On non-Unix platforms the file's permissions are left unchanged.
#[cfg(not(unix))]
fn restrict_permissions(_path: &std::path::Path) {}
696
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_legacy_top_level_cloud() {
        let store = parse("key = \"gitorii_sk_abc\"");
        let cloud = store
            .cloud
            .as_ref()
            .expect("a top-level key should populate the cloud entry");
        assert_eq!(cloud.key, "gitorii_sk_abc");
        assert!(store.tokens.is_empty());
    }

    #[test]
    fn parse_new_sectioned_cloud_only() {
        let cloud = parse("[cloud]\nkey = \"x\"\nendpoint = \"http://h\"\n")
            .cloud
            .expect("a [cloud] section should populate the cloud entry");
        assert_eq!(
            (cloud.key.as_str(), cloud.endpoint.as_str()),
            ("x", "http://h")
        );
    }

    #[test]
    fn parse_tokens_only() {
        let store = parse("[tokens]\ngithub = \"ghp_x\"\ngitlab = \"glp_y\"\n");
        assert_eq!(store.tokens.get("github").map(String::as_str), Some("ghp_x"));
        assert_eq!(store.tokens.get("gitlab").map(String::as_str), Some("glp_y"));
        assert!(store.cloud.is_none());
    }

    #[test]
    fn parse_both_sections() {
        let store = parse("[cloud]\nkey = \"k\"\n[tokens]\ncargo = \"cio\"\n");
        assert_eq!(store.cloud.as_ref().map(|c| c.key.as_str()), Some("k"));
        assert_eq!(store.tokens.get("cargo").map(String::as_str), Some("cio"));
    }

    #[test]
    fn parse_empty_tokens_are_dropped() {
        let tokens = parse("[tokens]\ngithub = \"\"\ngitlab = \"x\"\n").tokens;
        assert!(!tokens.contains_key("github"));
        assert!(tokens.contains_key("gitlab"));
    }

    #[test]
    fn normalise_provider_accepts_known() {
        for (input, expected) in [("GitHub", "github"), ("cargo", "cargo")] {
            assert_eq!(normalise_provider(input).unwrap(), expected);
        }
    }

    #[test]
    fn normalise_provider_rejects_unknown() {
        let outcome = normalise_provider("hackernews");
        assert!(outcome.is_err());
    }
}
750}