1use std::path::{Path, PathBuf};
49
50use serde::Deserialize;
51
52use crate::errors::{SafeError, SafeResult};
53
/// Authentication settings for a HashiCorp Vault (`source: hcp`) pull source.
///
/// Deserialized as an internally tagged enum: the `method` key selects the
/// variant by its lowercased name (`method: token` or `method: approle`).
///
/// NOTE(review): the derived `Debug` prints `token`, `role_id` and `secret_id`
/// verbatim, so avoid logging values of this type with `{:?}`.
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(tag = "method", rename_all = "lowercase")]
pub enum VaultAuthConfig {
    /// Token authentication (`method: token`).
    Token {
        /// The token value. `None` when the `token` key is omitted.
        // NOTE(review): presumably the caller falls back to another token source
        // when this is None — confirm at the call site.
        #[serde(default)]
        token: Option<String>,
    },
    /// AppRole authentication (`method: approle`). Both keys are required (no
    /// serde default). `${VAR}` placeholders in them can be expanded with
    /// `VaultAuthConfig::expand_env_vars`.
    Approle { role_id: String, secret_id: String },
}
74
75impl VaultAuthConfig {
76 pub fn expand_env_vars(self) -> Self {
79 match self {
80 VaultAuthConfig::Approle { role_id, secret_id } => VaultAuthConfig::Approle {
81 role_id: expand_env_var_str(&role_id),
82 secret_id: expand_env_var_str(&secret_id),
83 },
84 other => other,
85 }
86 }
87}
88
/// Expands `${NAME}` placeholders in `s` from the process environment.
///
/// - `${NAME}` is replaced by the value of `NAME` when `std::env::var` succeeds.
/// - If `NAME` is unset, or its value is not valid Unicode, the placeholder is
///   kept verbatim.
/// - A trailing `${` with no closing `}` is kept verbatim, exactly once.
/// - Substituted values are not re-scanned, so a value that itself contains
///   `${...}` is inserted literally (no recursive expansion).
/// - Text outside placeholders, including a bare `$` or `{`, is copied unchanged.
pub fn expand_env_var_str(s: &str) -> String {
    // Fast path: nothing that could be a placeholder.
    if !s.contains("${") {
        return s.to_string();
    }

    let mut result = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(end) => {
                let var_name = &after_open[..end];
                match std::env::var(var_name) {
                    Ok(val) => result.push_str(&val),
                    // Unknown (or non-UTF-8) variable: leave the placeholder as written.
                    Err(_) => {
                        result.push_str("${");
                        result.push_str(var_name);
                        result.push('}');
                    }
                }
                rest = &after_open[end + 1..];
            }
            None => {
                // Unterminated `${`: keep everything from the `${` onward and let the
                // single push after the loop emit it. (Previously the tail was pushed
                // here AND after the loop, duplicating it.)
                rest = &rest[start..];
                break;
            }
        }
    }
    result.push_str(rest);
    result
}
128
/// Root of a pull config document, as parsed by `load`.
///
/// The `pulls` key is required (it has no serde default). An empty list is valid.
#[derive(Debug, Deserialize)]
pub struct PullConfig {
    /// Configured pull sources, in the order they appear in the document.
    pub pulls: Vec<PullSource>,
}
134
/// One entry of `pulls`: a provider to pull secrets from.
///
/// The enum is internally tagged by the `source` key. The accepted tags are
/// `akv`, `hcp`, `op`, `aws`, `ssm`, `gcp`, `bw` and `kp`; these are the same
/// strings returned by `PullSource::provider_type`.
///
/// Every variant has:
/// - an optional `name` and `ns`, read back via `PullSource::name` and `PullSource::ns`;
/// - an `overwrite` flag that defaults to `false` when omitted.
///
/// NOTE(review): `deny_unknown_fields` is not set, so misspelled keys are
/// silently ignored instead of rejected.
/// NOTE(review): the derived `Debug` prints every field, including Bitwarden
/// `client_secret` and the HCP `auth` credentials.
#[derive(Debug, Deserialize)]
#[serde(tag = "source")]
pub enum PullSource {
    /// Azure Key Vault (`source: akv`). `vault_url` is required.
    #[serde(rename = "akv")]
    AzureKeyVault {
        /// Optional label for this source.
        #[serde(default)]
        name: Option<String>,
        /// Optional `ns` value. Presumably a target namespace — confirm with callers.
        #[serde(default)]
        ns: Option<String>,
        vault_url: String,
        /// Optional prefix. How it is applied is up to the provider — TODO confirm.
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// HashiCorp Vault (`source: hcp`).
    #[serde(rename = "hcp")]
    HashiCorpVault {
        /// Optional label for this source.
        #[serde(default)]
        name: Option<String>,
        /// Optional `ns` value. Presumably a target namespace — confirm with callers.
        #[serde(default)]
        ns: Option<String>,
        /// Server address. Defaults to `http://127.0.0.1:8200` (`default_hcp_addr`).
        #[serde(default = "default_hcp_addr")]
        addr: String,
        /// Mount name. Defaults to `secret` (`default_mount`).
        #[serde(default = "default_mount")]
        mount: String,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
        /// Explicit auth method. `None` when the `auth` key is omitted.
        #[serde(default)]
        auth: Option<VaultAuthConfig>,
        /// Optional Vault namespace. Presumably sent to the server as the request
        /// namespace — confirm in the provider.
        #[serde(default)]
        vault_namespace: Option<String>,
    },
    /// 1Password (`source: op`). `item` is required.
    #[serde(rename = "op")]
    OnePassword {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        item: String,
        /// Optional 1Password vault name. Named `op_vault` to avoid clashing with other keys.
        #[serde(default)]
        op_vault: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// AWS (`source: aws`). All provider-specific keys are optional.
    #[serde(rename = "aws")]
    Aws {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        region: Option<String>,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// AWS SSM Parameter Store (`source: ssm`). All provider-specific keys are optional.
    #[serde(rename = "ssm")]
    SsmParameterStore {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        region: Option<String>,
        /// Optional parameter path. Its root or default is chosen by the provider — TODO confirm.
        #[serde(default)]
        path: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// Google Cloud (`source: gcp`). All provider-specific keys are optional.
    #[serde(rename = "gcp")]
    Gcp {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        project: Option<String>,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// Bitwarden (`source: bw`). All provider-specific keys are optional.
    #[serde(rename = "bw")]
    Bitwarden {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        /// Optional API endpoint override.
        #[serde(default)]
        api_url: Option<String>,
        /// Optional identity endpoint override.
        #[serde(default)]
        identity_url: Option<String>,
        #[serde(default)]
        client_id: Option<String>,
        /// NOTE(review): a secret stored in the config file itself. Prefer
        /// `password_env` or similar indirection where possible.
        #[serde(default)]
        client_secret: Option<String>,
        #[serde(default)]
        folder: Option<String>,
        /// Optional env-var name. Presumably the variable holding a password — confirm.
        #[serde(default)]
        password_env: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// KeePass (`source: kp`). `path` is required.
    #[serde(rename = "kp")]
    Keepass {
        #[serde(default)]
        name: Option<String>,
        path: String,
        /// Optional env-var name. Presumably the variable holding the database
        /// password — confirm.
        #[serde(default)]
        password_env: Option<String>,
        #[serde(default)]
        keyfile_path: Option<String>,
        /// Optional group to read from. `None` presumably means the whole database — confirm.
        #[serde(default)]
        group: Option<String>,
        /// Optional recursion flag. `None` leaves the choice to the provider — TODO confirm.
        #[serde(default)]
        recursive: Option<bool>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
}
327
328impl PullSource {
329 pub fn name(&self) -> Option<&str> {
331 match self {
332 PullSource::AzureKeyVault { name, .. }
333 | PullSource::HashiCorpVault { name, .. }
334 | PullSource::OnePassword { name, .. }
335 | PullSource::Aws { name, .. }
336 | PullSource::SsmParameterStore { name, .. }
337 | PullSource::Gcp { name, .. }
338 | PullSource::Bitwarden { name, .. } => name.as_deref(),
339 PullSource::Keepass { name, .. } => name.as_deref(),
340 }
341 }
342
343 pub fn ns(&self) -> Option<&str> {
347 match self {
348 PullSource::AzureKeyVault { ns, .. }
349 | PullSource::HashiCorpVault { ns, .. }
350 | PullSource::OnePassword { ns, .. }
351 | PullSource::Aws { ns, .. }
352 | PullSource::SsmParameterStore { ns, .. }
353 | PullSource::Gcp { ns, .. }
354 | PullSource::Bitwarden { ns, .. } => ns.as_deref(),
355 PullSource::Keepass { ns, .. } => ns.as_deref(),
356 }
357 }
358
359 pub fn provider_type(&self) -> &'static str {
361 match self {
362 PullSource::AzureKeyVault { .. } => "akv",
363 PullSource::HashiCorpVault { .. } => "hcp",
364 PullSource::OnePassword { .. } => "op",
365 PullSource::Aws { .. } => "aws",
366 PullSource::SsmParameterStore { .. } => "ssm",
367 PullSource::Gcp { .. } => "gcp",
368 PullSource::Bitwarden { .. } => "bw",
369 PullSource::Keepass { .. } => "kp",
370 }
371 }
372}
373
/// Serde default for the HashiCorp Vault `addr` key.
///
/// Plain HTTP to the loopback address 127.0.0.1 on port 8200.
fn default_hcp_addr() -> String {
    String::from("http://127.0.0.1:8200")
}
/// Serde default for the HashiCorp Vault `mount` key.
fn default_mount() -> String {
    String::from("secret")
}
380
/// Looks for a pull config file in `start` and then in each of its ancestors.
///
/// Within a directory, `.tsafe.yml` is checked before `.tsafe.json`. The nearest
/// existing candidate wins. Returns `None` when no ancestor has either file.
///
/// `start` is not canonicalized. For a relative path, the walk ends at the current
/// directory and never climbs above it.
///
/// NOTE(review): `Path::exists` is also true for directories, so a directory named
/// `.tsafe.yml` would be returned.
pub fn find_config(start: &Path) -> Option<PathBuf> {
    // Candidate file names, in per-directory priority order.
    const CANDIDATES: [&str; 2] = [".tsafe.yml", ".tsafe.json"];

    start
        .ancestors()
        .flat_map(|dir| CANDIDATES.iter().map(move |file| dir.join(file)))
        .find(|candidate| candidate.exists())
}
398
399pub fn load(path: &Path) -> SafeResult<PullConfig> {
401 let content = std::fs::read_to_string(path)?;
402 let is_json = path
403 .extension()
404 .and_then(|e| e.to_str())
405 .map(|e| e == "json")
406 .unwrap_or(false);
407 if is_json {
408 serde_json::from_str(&content).map_err(|e| SafeError::InvalidVault {
409 reason: format!("invalid pull config JSON: {e}"),
410 })
411 } else {
412 serde_yaml::from_str(&content).map_err(|e| SafeError::InvalidVault {
413 reason: format!("invalid pull config YAML: {e}"),
414 })
415 }
416}
417
/// Unit tests for pull-config parsing, discovery and `${VAR}` expansion.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // A mixed YAML config yields one PullSource per entry, and the AKV entry
    // keeps its explicit fields.
    #[test]
    fn parse_yaml_config() {
        let yaml = r#"
pulls:
  - source: akv
    vault_url: https://myvault.vault.azure.net
    prefix: MYAPP_
    overwrite: true
  - source: hcp
    addr: http://vault:8200
    mount: secret
    prefix: myapp/
  - source: op
    item: Database Credentials
    op_vault: Infrastructure
  - source: aws
    region: us-east-1
    prefix: myapp/
  - source: gcp
    project: my-gcp-project
    prefix: myapp-
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 5);
        match &cfg.pulls[0] {
            PullSource::AzureKeyVault {
                vault_url,
                prefix,
                overwrite,
                ..
            } => {
                assert_eq!(vault_url, "https://myvault.vault.azure.net");
                assert_eq!(prefix.as_deref(), Some("MYAPP_"));
                assert!(overwrite);
            }
            other => panic!("expected AzureKeyVault, got {other:?}"),
        }
    }

    // The same schema is accepted from JSON.
    #[test]
    fn parse_json_config() {
        let json = r#"{"pulls": [{"source": "op", "item": "Test"}]}"#;
        let cfg: PullConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.pulls.len(), 1);
    }

    // find_config climbs from a nested directory to the one holding .tsafe.yml.
    #[test]
    fn find_config_walks_up() {
        let dir = tempdir().unwrap();
        let child = dir.path().join("a/b/c");
        std::fs::create_dir_all(&child).unwrap();
        let cfg_path = dir.path().join(".tsafe.yml");
        std::fs::write(&cfg_path, "pulls: []").unwrap();
        let found = find_config(&child).unwrap();
        assert_eq!(found, cfg_path);
    }

    // An empty temp dir has no config.
    // NOTE(review): this would fail if an ancestor of the system temp dir contains
    // .tsafe.yml or .tsafe.json, because find_config keeps walking up.
    #[test]
    fn find_config_returns_none() {
        let dir = tempdir().unwrap();
        assert!(find_config(dir.path()).is_none());
    }

    // name/ns are read through the accessors, and provider_type matches the tag.
    #[test]
    fn parse_name_and_ns_fields() {
        let yaml = r#"
pulls:
  - source: akv
    name: prod-akv
    ns: prod
    vault_url: https://prod.vault.azure.net
  - source: aws
    name: staging-aws
    ns: staging
    region: us-east-1
  - source: gcp
    project: my-project
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 3);

        assert_eq!(cfg.pulls[0].name(), Some("prod-akv"));
        assert_eq!(cfg.pulls[0].ns(), Some("prod"));
        assert_eq!(cfg.pulls[0].provider_type(), "akv");

        assert_eq!(cfg.pulls[1].name(), Some("staging-aws"));
        assert_eq!(cfg.pulls[1].ns(), Some("staging"));
        assert_eq!(cfg.pulls[1].provider_type(), "aws");

        assert_eq!(cfg.pulls[2].name(), None);
        // Omitted name/ns fall back to the serde default (None).
        assert_eq!(cfg.pulls[2].ns(), None);
        assert_eq!(cfg.pulls[2].provider_type(), "gcp");
    }

    // Without explicit name/ns, every listed provider reports None for both.
    // (bw and kp are not exercised here.)
    #[test]
    fn name_and_ns_default_to_none() {
        let yaml = r#"
pulls:
  - source: akv
    vault_url: https://myvault.vault.azure.net
  - source: hcp
    addr: http://vault:8200
    mount: secret
  - source: op
    item: MyItem
  - source: aws
    region: us-east-1
  - source: ssm
    region: us-east-1
  - source: gcp
    project: my-project
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        for source in &cfg.pulls {
            assert_eq!(
                source.name(),
                None,
                "expected no name for {:?}",
                source.provider_type()
            );
            assert_eq!(
                source.ns(),
                None,
                "expected no ns for {:?}",
                source.provider_type()
            );
        }
    }

    // `auth: { method: token, token: ... }` deserializes to VaultAuthConfig::Token.
    #[test]
    fn parse_hcp_token_auth_from_yaml() {
        let yaml = r#"
pulls:
  - source: hcp
    addr: https://vault.example.com:8200
    auth:
      method: token
      token: hvs.my-static-token
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 1);
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault { auth, .. } => {
                assert!(
                    matches!(
                        auth,
                        Some(VaultAuthConfig::Token {
                            token: Some(t)
                        }) if t == "hvs.my-static-token"
                    ),
                    "expected Token auth with static token, got {auth:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // `method: approle` selects the Approle variant (lowercased tag).
    #[test]
    fn parse_hcp_approle_auth_from_yaml() {
        let yaml = r#"
pulls:
  - source: hcp
    addr: https://vault.example.com:8200
    auth:
      method: approle
      role_id: my-role-123
      secret_id: my-secret-456
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault { auth, .. } => {
                assert!(
                    matches!(
                        auth,
                        Some(VaultAuthConfig::Approle { role_id, secret_id })
                        if role_id == "my-role-123" && secret_id == "my-secret-456"
                    ),
                    "expected AppRole auth, got {auth:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // vault_namespace is read verbatim.
    #[test]
    fn parse_hcp_vault_namespace_from_yaml() {
        let yaml = r#"
pulls:
  - source: hcp
    addr: https://vault.example.com:8200
    vault_namespace: team-alpha
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault {
                vault_namespace, ..
            } => {
                assert_eq!(vault_namespace.as_deref(), Some("team-alpha"));
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // Omitting auth and vault_namespace leaves both as None.
    #[test]
    fn parse_hcp_defaults_auth_and_namespace_to_none() {
        let yaml = r#"
pulls:
  - source: hcp
    addr: http://127.0.0.1:8200
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault {
                auth,
                vault_namespace,
                ..
            } => {
                assert!(auth.is_none(), "expected auth=None, got {auth:?}");
                assert!(
                    vault_namespace.is_none(),
                    "expected vault_namespace=None, got {vault_namespace:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // A set variable is substituted. temp_env scopes the variable to the closure.
    #[test]
    fn expand_env_var_str_replaces_placeholder() {
        temp_env::with_var("TEST_SECRET_ID", Some("s-abc-123"), || {
            let result = expand_env_var_str("${TEST_SECRET_ID}");
            assert_eq!(result, "s-abc-123");
        });
    }

    // Strings without `${` take the fast path unchanged.
    #[test]
    fn expand_env_var_str_no_placeholder_passthrough() {
        let result = expand_env_var_str("plain-secret-id");
        assert_eq!(result, "plain-secret-id");
    }

    // An unset variable leaves its placeholder verbatim.
    #[test]
    fn expand_env_var_str_unknown_var_left_as_is() {
        temp_env::with_var("VAULT_UNKNOWN_9999", None::<&str>, || {
            let result = expand_env_var_str("${VAULT_UNKNOWN_9999}");
            assert_eq!(result, "${VAULT_UNKNOWN_9999}");
        });
    }

    // Approle credentials are expanded; literal values pass through.
    #[test]
    fn vault_auth_config_expand_env_vars_in_approle() {
        temp_env::with_var("MY_SECRET_ID", Some("expanded-sid"), || {
            let auth = VaultAuthConfig::Approle {
                role_id: "static-role".into(),
                secret_id: "${MY_SECRET_ID}".into(),
            };
            let expanded = auth.expand_env_vars();
            assert!(
                matches!(
                    expanded,
                    VaultAuthConfig::Approle { ref role_id, ref secret_id }
                    if role_id == "static-role" && secret_id == "expanded-sid"
                ),
                "expected expanded secret_id, got {expanded:?}"
            );
        });
    }

    // Token configs are returned as-is by expand_env_vars.
    #[test]
    fn vault_auth_config_expand_env_vars_token_unchanged() {
        let auth = VaultAuthConfig::Token {
            token: Some("hvs.static".into()),
        };
        let expanded = auth.expand_env_vars();
        assert!(
            matches!(expanded, VaultAuthConfig::Token { token: Some(ref t) } if t == "hvs.static"),
            "expected token unchanged, got {expanded:?}"
        );
    }
}