1use crate::config::{Config, ConfigError, Host};
2use std::path::Path;
3
/// Atomically replaces the contents of `path` with `content`.
///
/// The data is written to a sibling temporary file (`<file name>.tmp` in the
/// same directory), flushed to disk, and then renamed over the destination,
/// so readers never observe a half-written file.
///
/// Details that matter for files such as `~/.ssh/config`:
/// * If `path` is a symlink (e.g. managed by a dotfiles tool), the write goes
///   through to the link's target instead of replacing the link itself. A
///   dangling link is replaced, as before.
/// * The permissions of an existing destination file are copied onto the new
///   file; otherwise the process defaults (umask) apply.
///
/// # Errors
/// Returns any I/O error from creating, writing, syncing or renaming the
/// temporary file, or `InvalidInput` if `path` has no file name. On failure the
/// temporary file is removed on a best-effort basis.
fn atomic_write(path: &Path, content: &str) -> Result<(), std::io::Error> {
    use std::io::Write;

    // Write through symlinks so a linked config keeps being a link.
    let is_symlink = std::fs::symlink_metadata(path)
        .map(|meta| meta.file_type().is_symlink())
        .unwrap_or(false);
    let target = if is_symlink {
        std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
    } else {
        path.to_path_buf()
    };

    let file_name = target.file_name().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("cannot write to {}: path has no file name", target.display()),
        )
    })?;
    // Append ".tmp" to the whole file name rather than replacing the extension:
    // `ssm-hosts.conf` must not clobber an unrelated `ssm-hosts.tmp`, and a target
    // already ending in `.tmp` must not use itself as the temporary file.
    let mut tmp_name = file_name.to_os_string();
    tmp_name.push(".tmp");
    let tmp_path = target.with_file_name(tmp_name);

    let existing_permissions = std::fs::metadata(&target)
        .ok()
        .map(|meta| meta.permissions());

    let write_and_rename = || -> std::io::Result<()> {
        {
            let mut file = std::fs::File::create(&tmp_path)?;
            // Restrict permissions before any data lands in the file.
            if let Some(permissions) = existing_permissions {
                file.set_permissions(permissions)?;
            }
            file.write_all(content.as_bytes())?;
            // Make sure the bytes are durable before the rename publishes them.
            file.sync_all()?;
        }
        std::fs::rename(&tmp_path, &target)
    };

    write_and_rename().map_err(|err| {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = std::fs::remove_file(&tmp_path);
        err
    })
}
13
14pub fn generate_host_block(host: &Host) -> String {
15 let mut lines = vec![format!("Host {}", host.alias)];
16 lines.push(format!(" HostName {}", host.hostname));
17 if let Some(ref user) = host.user {
18 lines.push(format!(" User {}", user));
19 }
20 if host.port != 22 {
21 lines.push(format!(" Port {}", host.port));
22 }
23 if let Some(ref key) = host.identity_file {
24 lines.push(format!(" IdentityFile {}", key.display()));
25 }
26 lines.join("\n")
27}
28
29pub fn generate_config(config: &Config) -> String {
30 let header = "# Generated by ssm. Do not edit manually.\n";
31 let blocks: Vec<String> = config.hosts.iter().map(generate_host_block).collect();
32 format!("{}{}\n", header, blocks.join("\n\n"))
33}
34
35pub fn write_generated_config(config: &Config) -> Result<(), ConfigError> {
36 let path = &config.settings.generated_config_path;
37 if let Some(parent) = path.parent() {
38 std::fs::create_dir_all(parent)?;
39 }
40 atomic_write(path, &generate_config(config))?;
41 Ok(())
42}
43
44const INCLUDE_DIRECTIVE: &str = "Include ssm-hosts.conf";
45
46pub fn ensure_include_directive(ssh_config_path: &Path) -> Result<(), ConfigError> {
47 if let Some(parent) = ssh_config_path.parent() {
48 std::fs::create_dir_all(parent)?;
49 }
50
51 if ssh_config_path.exists() {
52 let content = std::fs::read_to_string(ssh_config_path)?;
53 if content.lines().any(|line| line.trim() == INCLUDE_DIRECTIVE) {
54 return Ok(());
55 }
56 let new_content = format!("{}\n\n{}", INCLUDE_DIRECTIVE, content);
57 atomic_write(ssh_config_path, &new_content)?;
58 } else {
59 atomic_write(ssh_config_path, &format!("{}\n", INCLUDE_DIRECTIVE))?;
60 }
61 Ok(())
62}
63
64pub fn sync_ssh_config(config: &Config) -> Result<(), ConfigError> {
65 write_generated_config(config)?;
66 ensure_include_directive(&config.settings.ssh_config_path)?;
67 Ok(())
68}
69
#[cfg(test)]
mod tests {
    //! Unit tests for host-block rendering and the `Include` bootstrap logic.

    use super::*;
    use crate::config::{Host, Settings};
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Baseline host fixture: port 22, explicit user and identity file.
    fn test_host() -> Host {
        Host {
            alias: "prod-api".into(),
            hostname: "10.0.1.50".into(),
            user: Some("deploy".into()),
            port: 22,
            identity_file: Some(PathBuf::from("~/.ssh/id_ed25519")),
            tags: vec!["prod".into()],
            notes: None,
            tunnels: vec![],
            commands: vec![],
        }
    }

    /// A fresh temp dir plus the path of a (not yet written) `config` file in it.
    /// The `TempDir` is returned so the caller keeps the directory alive.
    fn scratch_ssh_config() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("config");
        (dir, path)
    }

    #[test]
    fn test_generate_host_block_basic() {
        let block = generate_host_block(&test_host());
        for expected in [
            "Host prod-api",
            "HostName 10.0.1.50",
            "User deploy",
            "IdentityFile ~/.ssh/id_ed25519",
        ] {
            assert!(block.contains(expected), "missing {:?} in:\n{}", expected, block);
        }
        // Port 22 is the default, so no Port line should be rendered.
        assert!(!block.contains("Port"));
    }

    #[test]
    fn test_generate_host_block_custom_port() {
        let host = Host {
            port: 2222,
            ..test_host()
        };
        assert!(generate_host_block(&host).contains("Port 2222"));
    }

    #[test]
    fn test_generate_config_multiple_hosts() {
        let staging = Host {
            alias: "staging".into(),
            hostname: "10.0.2.50".into(),
            ..test_host()
        };
        let config = Config {
            settings: Settings::default(),
            hosts: vec![test_host(), staging],
            scenarios: vec![],
        };

        let rendered = generate_config(&config);

        assert!(rendered.starts_with("# Generated by ssm"));
        assert!(rendered.contains("Host prod-api"));
        assert!(rendered.contains("Host staging"));
    }

    #[test]
    fn test_ensure_include_creates_new_file() {
        let (_dir, ssh_config) = scratch_ssh_config();

        ensure_include_directive(&ssh_config).unwrap();

        let written = std::fs::read_to_string(&ssh_config).unwrap();
        assert_eq!(written.trim(), INCLUDE_DIRECTIVE);
    }

    #[test]
    fn test_ensure_include_prepends_to_existing() {
        let (_dir, ssh_config) = scratch_ssh_config();
        std::fs::write(&ssh_config, "Host myserver\n HostName 1.2.3.4\n").unwrap();

        ensure_include_directive(&ssh_config).unwrap();

        let written = std::fs::read_to_string(&ssh_config).unwrap();
        assert!(written.starts_with(INCLUDE_DIRECTIVE));
        assert!(written.contains("Host myserver"));
    }

    #[test]
    fn test_ensure_include_idempotent() {
        let (_dir, ssh_config) = scratch_ssh_config();
        let seeded = format!("{}\n\nHost myserver\n", INCLUDE_DIRECTIVE);
        std::fs::write(&ssh_config, seeded).unwrap();

        ensure_include_directive(&ssh_config).unwrap();

        let written = std::fs::read_to_string(&ssh_config).unwrap();
        let occurrences = written.matches(INCLUDE_DIRECTIVE).count();
        assert_eq!(occurrences, 1, "Include should not be duplicated");
    }

    #[test]
    fn test_sync_ssh_config_end_to_end() {
        let dir = TempDir::new().unwrap();
        let ssh_config_path = dir.path().join("config");
        let generated_config_path = dir.path().join("ssm-hosts.conf");
        let config = Config {
            settings: Settings {
                ssh_config_path: ssh_config_path.clone(),
                generated_config_path: generated_config_path.clone(),
            },
            hosts: vec![test_host()],
            scenarios: vec![],
        };

        sync_ssh_config(&config).unwrap();

        let generated = std::fs::read_to_string(&generated_config_path).unwrap();
        assert!(generated.contains("Host prod-api"));

        let ssh_config = std::fs::read_to_string(&ssh_config_path).unwrap();
        assert!(ssh_config.contains(INCLUDE_DIRECTIVE));
    }
}