Skip to main content

ssm_core/
ssh_config.rs

1use crate::config::{Config, ConfigError, Host};
2use std::path::Path;
3
/// Write `content` to `path` atomically by first writing a sibling temporary
/// file and then renaming it over the target. On POSIX, rename(2) is atomic,
/// so readers always see either the old or the new file, never a partial write.
///
/// Details:
/// - The temporary file is `<file name>.tmp` next to the target. The whole file
///   name is kept (`ssm-hosts.conf` -> `ssm-hosts.conf.tmp`), so it can never
///   collide with an unrelated `ssm-hosts.tmp`. It also never equals the target
///   itself, even when the target already ends in `.tmp`.
/// - If `path` resolves through a symlink, the file the link points at is
///   replaced and the link itself is left intact.
/// - If the target already exists, its permissions are copied onto the new file
///   before any content is written.
/// - The data is flushed with `sync_all` before the rename. This lowers the risk
///   that a crash leaves an empty or truncated file in place of the old one.
/// - On any failure the temporary file is removed on a best-effort basis.
///
/// # Errors
///
/// Returns `InvalidInput` if `path` has no file name. Otherwise returns any I/O
/// error from creating, writing, syncing or renaming the temporary file.
fn atomic_write(path: &Path, content: &str) -> Result<(), std::io::Error> {
    /// Create `tmp`, copy `target`'s permissions onto it if `target` exists,
    /// write and fsync `bytes`, then rename `tmp` over `target`.
    fn write_then_rename(tmp: &Path, target: &Path, bytes: &[u8]) -> std::io::Result<()> {
        use std::io::Write;

        let mut file = std::fs::File::create(tmp)?;
        // Apply the final permissions before any content lands in the file.
        if let Ok(existing) = std::fs::metadata(target) {
            file.set_permissions(existing.permissions())?;
        }
        file.write_all(bytes)?;
        file.sync_all()?;
        // Close the temporary file before it is renamed into place.
        drop(file);
        std::fs::rename(tmp, target)
    }

    // Follow symlinks so the rename replaces the file the link points at, not
    // the link itself. A target that does not exist yet is used as given.
    let target = std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
    let file_name = target.file_name().ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "path has no file name")
    })?;
    // Append `.tmp` to the full file name rather than replacing the extension.
    let mut tmp_name = file_name.to_os_string();
    tmp_name.push(".tmp");
    let tmp_path = target.with_file_name(tmp_name);

    let result = write_then_rename(&tmp_path, &target, content.as_bytes());
    if result.is_err() {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = std::fs::remove_file(&tmp_path);
    }
    result
}
13
14pub fn generate_host_block(host: &Host) -> String {
15    let mut lines = vec![format!("Host {}", host.alias)];
16    lines.push(format!("    HostName {}", host.hostname));
17    if let Some(ref user) = host.user {
18        lines.push(format!("    User {}", user));
19    }
20    if host.port != 22 {
21        lines.push(format!("    Port {}", host.port));
22    }
23    if let Some(ref key) = host.identity_file {
24        lines.push(format!("    IdentityFile {}", key.display()));
25    }
26    lines.join("\n")
27}
28
29pub fn generate_config(config: &Config) -> String {
30    let header = "# Generated by ssm. Do not edit manually.\n";
31    let blocks: Vec<String> = config.hosts.iter().map(generate_host_block).collect();
32    format!("{}{}\n", header, blocks.join("\n\n"))
33}
34
35pub fn write_generated_config(config: &Config) -> Result<(), ConfigError> {
36    let path = &config.settings.generated_config_path;
37    if let Some(parent) = path.parent() {
38        std::fs::create_dir_all(parent)?;
39    }
40    atomic_write(path, &generate_config(config))?;
41    Ok(())
42}
43
44const INCLUDE_DIRECTIVE: &str = "Include ssm-hosts.conf";
45
46pub fn ensure_include_directive(ssh_config_path: &Path) -> Result<(), ConfigError> {
47    if let Some(parent) = ssh_config_path.parent() {
48        std::fs::create_dir_all(parent)?;
49    }
50
51    if ssh_config_path.exists() {
52        let content = std::fs::read_to_string(ssh_config_path)?;
53        if content.lines().any(|line| line.trim() == INCLUDE_DIRECTIVE) {
54            return Ok(());
55        }
56        let new_content = format!("{}\n\n{}", INCLUDE_DIRECTIVE, content);
57        atomic_write(ssh_config_path, &new_content)?;
58    } else {
59        atomic_write(ssh_config_path, &format!("{}\n", INCLUDE_DIRECTIVE))?;
60    }
61    Ok(())
62}
63
64pub fn sync_ssh_config(config: &Config) -> Result<(), ConfigError> {
65    write_generated_config(config)?;
66    ensure_include_directive(&config.settings.ssh_config_path)?;
67    Ok(())
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Host, Settings};
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Representative host: default port, explicit user and identity file.
    fn test_host() -> Host {
        Host {
            alias: "prod-api".into(),
            hostname: "10.0.1.50".into(),
            user: Some("deploy".into()),
            port: 22,
            identity_file: Some(PathBuf::from("~/.ssh/id_ed25519")),
            tags: vec!["prod".into()],
            notes: None,
            tunnels: vec![],
            commands: vec![],
        }
    }

    /// Same as `test_host` but with a different alias and address.
    fn staging_host() -> Host {
        let mut host = test_host();
        host.alias = "staging".into();
        host.hostname = "10.0.2.50".into();
        host
    }

    /// Config whose ssh config and generated file both live inside `dir`.
    fn config_in(dir: &TempDir, hosts: Vec<Host>) -> Config {
        Config {
            settings: Settings {
                ssh_config_path: dir.path().join("config"),
                generated_config_path: dir.path().join("ssm-hosts.conf"),
            },
            hosts,
            scenarios: vec![],
        }
    }

    /// Read a file that the test expects to exist.
    fn read(path: impl AsRef<std::path::Path>) -> String {
        std::fs::read_to_string(path).unwrap()
    }

    #[test]
    fn test_generate_host_block_basic() {
        let block = generate_host_block(&test_host());
        // Exact comparison pins directive order and indentation, not just presence.
        let expected = concat!(
            "Host prod-api\n",
            "    HostName 10.0.1.50\n",
            "    User deploy\n",
            "    IdentityFile ~/.ssh/id_ed25519",
        );
        assert_eq!(block, expected);
        assert!(!block.contains("Port"));
    }

    #[test]
    fn test_generate_host_block_minimal() {
        let mut host = test_host();
        host.user = None;
        host.identity_file = None;
        assert_eq!(generate_host_block(&host), "Host prod-api\n    HostName 10.0.1.50");
    }

    #[test]
    fn test_generate_host_block_custom_port() {
        let mut host = test_host();
        host.port = 2222;
        let block = generate_host_block(&host);
        assert!(block.contains("Port 2222"));
        // Port is emitted between User and IdentityFile.
        assert!(block.contains("    User deploy\n    Port 2222\n    IdentityFile "));
    }

    #[test]
    fn test_generate_config_multiple_hosts() {
        let config = Config {
            settings: Settings::default(),
            hosts: vec![test_host(), staging_host()],
            scenarios: vec![],
        };
        let output = generate_config(&config);
        assert!(output.starts_with("# Generated by ssm"));
        assert!(output.contains("Host prod-api"));
        assert!(output.contains("Host staging"));
        // Blocks are separated by one blank line and the file ends with a newline.
        let expected = format!(
            "# Generated by ssm. Do not edit manually.\n{}\n\n{}\n",
            generate_host_block(&test_host()),
            generate_host_block(&staging_host()),
        );
        assert_eq!(output, expected);
    }

    #[test]
    fn test_generate_config_no_hosts() {
        let config = Config {
            settings: Settings::default(),
            hosts: vec![],
            scenarios: vec![],
        };
        assert_eq!(
            generate_config(&config),
            "# Generated by ssm. Do not edit manually.\n\n"
        );
    }

    #[test]
    fn test_ensure_include_creates_new_file() {
        let dir = TempDir::new().unwrap();
        let ssh_config = dir.path().join("config");
        ensure_include_directive(&ssh_config).unwrap();
        assert_eq!(read(&ssh_config), format!("{}\n", INCLUDE_DIRECTIVE));
    }

    #[test]
    fn test_ensure_include_creates_parent_dirs() {
        let dir = TempDir::new().unwrap();
        let ssh_config = dir.path().join("nested").join(".ssh").join("config");
        ensure_include_directive(&ssh_config).unwrap();
        assert_eq!(read(&ssh_config).trim(), INCLUDE_DIRECTIVE);
    }

    #[test]
    fn test_ensure_include_prepends_to_existing() {
        let dir = TempDir::new().unwrap();
        let ssh_config = dir.path().join("config");
        let original = "Host myserver\n    HostName 1.2.3.4\n";
        std::fs::write(&ssh_config, original).unwrap();
        ensure_include_directive(&ssh_config).unwrap();
        let content = read(&ssh_config);
        assert!(content.starts_with(INCLUDE_DIRECTIVE));
        assert!(content.contains("Host myserver"));
        // Existing content is kept verbatim after a blank separator line.
        assert_eq!(content, format!("{}\n\n{}", INCLUDE_DIRECTIVE, original));
    }

    #[test]
    fn test_ensure_include_idempotent() {
        let dir = TempDir::new().unwrap();
        let ssh_config = dir.path().join("config");
        let existing = format!("{}\n\nHost myserver\n", INCLUDE_DIRECTIVE);
        std::fs::write(&ssh_config, &existing).unwrap();
        ensure_include_directive(&ssh_config).unwrap();
        let content = read(&ssh_config);
        assert_eq!(
            content.matches(INCLUDE_DIRECTIVE).count(),
            1,
            "Include should not be duplicated"
        );
        assert_eq!(content, existing, "an already-configured file is left untouched");
    }

    #[test]
    fn test_sync_ssh_config_end_to_end() {
        let dir = TempDir::new().unwrap();
        let config = config_in(&dir, vec![test_host()]);
        sync_ssh_config(&config).unwrap();

        let generated = read(dir.path().join("ssm-hosts.conf"));
        assert!(generated.contains("Host prod-api"));
        assert_eq!(generated, generate_config(&config));

        let ssh_config = read(dir.path().join("config"));
        assert!(ssh_config.contains(INCLUDE_DIRECTIVE));
    }

    #[test]
    fn test_sync_ssh_config_is_idempotent() {
        let dir = TempDir::new().unwrap();
        let config = config_in(&dir, vec![test_host(), staging_host()]);
        let generated_path = dir.path().join("ssm-hosts.conf");
        let ssh_config_path = dir.path().join("config");

        sync_ssh_config(&config).unwrap();
        let first_generated = read(&generated_path);
        let first_ssh_config = read(&ssh_config_path);

        sync_ssh_config(&config).unwrap();
        assert_eq!(read(&generated_path), first_generated);
        assert_eq!(read(&ssh_config_path), first_ssh_config);
        assert_eq!(first_ssh_config.matches(INCLUDE_DIRECTIVE).count(), 1);
    }

    #[test]
    fn test_sync_ssh_config_preserves_user_entries() {
        let dir = TempDir::new().unwrap();
        let user_entries = "Host myserver\n    HostName 1.2.3.4\n";
        std::fs::write(dir.path().join("config"), user_entries).unwrap();
        let config = config_in(&dir, vec![test_host()]);
        sync_ssh_config(&config).unwrap();

        let ssh_config = read(dir.path().join("config"));
        assert!(ssh_config.starts_with(INCLUDE_DIRECTIVE));
        assert!(ssh_config.ends_with(user_entries));
    }

    #[test]
    fn test_write_generated_config_creates_parent_dirs() {
        let dir = TempDir::new().unwrap();
        let mut config = config_in(&dir, vec![test_host()]);
        config.settings.generated_config_path =
            dir.path().join("a").join("b").join("ssm-hosts.conf");
        write_generated_config(&config).unwrap();
        assert_eq!(
            read(&config.settings.generated_config_path),
            generate_config(&config)
        );
    }
}