Skip to main content

ssm_core/
config.rs

1use serde::{Deserialize, Serialize};
2use std::path::PathBuf;
3use thiserror::Error;
4
5#[derive(Error, Debug)]
6pub enum ConfigError {
7    #[error("failed to read config: {0}")]
8    Read(#[from] std::io::Error),
9    #[error("failed to parse config: {0}")]
10    Parse(#[from] toml::de::Error),
11    #[error("failed to serialize config: {0}")]
12    Serialize(#[from] toml::ser::Error),
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct Settings {
17    pub ssh_config_path: PathBuf,
18    pub generated_config_path: PathBuf,
19}
20
21impl Default for Settings {
22    fn default() -> Self {
23        let ssh_dir = dirs::home_dir()
24            .unwrap_or_else(|| PathBuf::from("~"))
25            .join(".ssh");
26        Self {
27            ssh_config_path: ssh_dir.join("config"),
28            generated_config_path: ssh_dir.join("ssm-hosts.conf"),
29        }
30    }
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub struct TunnelConfig {
35    pub name: String,
36    pub local_port: u16,
37    pub remote_host: String,
38    pub remote_port: u16,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42pub struct CommandConfig {
43    pub name: String,
44    pub command: String,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48pub struct Host {
49    pub alias: String,
50    pub hostname: String,
51    pub user: Option<String>,
52    #[serde(default = "default_port")]
53    pub port: u16,
54    pub identity_file: Option<PathBuf>,
55    #[serde(default)]
56    pub tags: Vec<String>,
57    pub notes: Option<String>,
58    #[serde(default)]
59    pub tunnels: Vec<TunnelConfig>,
60    #[serde(default)]
61    pub commands: Vec<CommandConfig>,
62}
63
/// Standard SSH TCP port.
const DEFAULT_SSH_PORT: u16 = 22;

/// Serde default for `Host::port` when the key is absent.
fn default_port() -> u16 {
    DEFAULT_SSH_PORT
}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
69pub struct ScenarioTunnel {
70    pub host: String,
71    pub tunnel: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
75pub struct Scenario {
76    pub name: String,
77    pub tunnels: Vec<ScenarioTunnel>,
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
81pub struct Config {
82    #[serde(default)]
83    pub settings: Settings,
84    #[serde(default)]
85    pub hosts: Vec<Host>,
86    #[serde(default)]
87    pub scenarios: Vec<Scenario>,
88}
89
90impl Config {
91    pub fn config_dir() -> Result<PathBuf, ConfigError> {
92        let dir = dirs::config_dir()
93            .unwrap_or_else(|| PathBuf::from("~/.config"))
94            .join("ssm");
95        Ok(dir)
96    }
97
98    pub fn default_path() -> Result<PathBuf, ConfigError> {
99        Ok(Self::config_dir()?.join("config.toml"))
100    }
101
102    pub fn load(path: &std::path::Path) -> Result<Self, ConfigError> {
103        if !path.exists() {
104            return Ok(Self::default());
105        }
106        let content = std::fs::read_to_string(path)?;
107        let config: Config = toml::from_str(&content)?;
108        Ok(config)
109    }
110
111    pub fn save(&self, path: &std::path::Path) -> Result<(), ConfigError> {
112        if let Some(parent) = path.parent() {
113            std::fs::create_dir_all(parent)?;
114        }
115        let content = toml::to_string_pretty(self)?;
116        std::fs::write(path, content)?;
117        Ok(())
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Returns a fresh temporary directory and the path of `file_name` inside it.
    ///
    /// Callers must bind the `TempDir` to a named variable (not `_`): dropping
    /// it deletes the directory.
    fn scratch_path(file_name: &str) -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join(file_name);
        (dir, path)
    }

    /// A host with every field populated, including one tunnel and one command.
    fn sample_host() -> Host {
        Host {
            alias: "prod-api".into(),
            hostname: "10.0.1.50".into(),
            user: Some("deploy".into()),
            port: 22,
            identity_file: Some(PathBuf::from("~/.ssh/id_ed25519")),
            tags: vec!["prod".into(), "api".into()],
            notes: Some("Main API server".into()),
            tunnels: vec![TunnelConfig {
                name: "postgres".into(),
                local_port: 5432,
                remote_host: "localhost".into(),
                remote_port: 5432,
            }],
            commands: vec![CommandConfig {
                name: "logs".into(),
                command: "tail -f /var/log/app/api.log".into(),
            }],
        }
    }

    #[test]
    fn test_default_config_has_empty_hosts() {
        assert!(Config::default().hosts.is_empty());
    }

    #[test]
    fn test_roundtrip_empty_config() {
        let (_dir, path) = scratch_path("config.toml");
        Config::default().save(&path).unwrap();
        let reloaded = Config::load(&path).unwrap();
        assert_eq!(reloaded.hosts.len(), 0);
    }

    #[test]
    fn test_roundtrip_with_hosts() {
        let (_dir, path) = scratch_path("config.toml");
        let original = Config {
            settings: Settings::default(),
            hosts: vec![sample_host()],
            scenarios: Vec::new(),
        };
        original.save(&path).unwrap();
        let reloaded = Config::load(&path).unwrap();
        assert_eq!(original, reloaded);
    }

    #[test]
    fn test_load_nonexistent_returns_default() {
        let (_dir, path) = scratch_path("does-not-exist.toml");
        let loaded = Config::load(&path).unwrap();
        assert!(loaded.hosts.is_empty());
    }

    #[test]
    fn test_default_port_is_22() {
        let toml_str = r#"
[[hosts]]
alias = "test"
hostname = "1.2.3.4"
"#;
        let parsed: Config = toml::from_str(toml_str).unwrap();
        assert_eq!(parsed.hosts[0].port, 22);
    }
}