//! Persistent known-hosts store for SSH server host keys
//! (`ssh_commander_core/ssh/host_keys.rs`).

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::{Context, Result};
use russh_keys::PublicKeyBase64;
use tokio::sync::Mutex;

/// Outcome of checking a server's offered host key against the [`HostKeyStore`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Verdict {
    /// The offered key is identical to the key stored for this host and port.
    Known,
    /// No key is stored for this host and port yet.
    Unknown,
    /// A different key is stored for this host and port.
    Mismatch {
        /// Fingerprint of the stored key, or `<unparseable stored key>` when
        /// the stored blob could not be decoded.
        expected_fingerprint: String,
        /// Fingerprint of the key the server offered.
        got_fingerprint: String,
    },
}

/// Describes a host whose offered key differs from the key in the store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HostKeyMismatch {
    /// Host name or address the key was checked for.
    pub host: String,
    /// Port the key was checked for.
    pub port: u16,
    /// Fingerprint of the key recorded in the known-hosts store.
    pub expected_fingerprint: String,
    /// Fingerprint of the key the server actually offered.
    pub got_fingerprint: String,
    /// Path of the known-hosts file holding the expected key.
    pub store_path: PathBuf,
}

/// Failure to access the known-hosts store while verifying a host key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HostKeyStoreAccessError {
    /// Host name or address being verified when the store access failed.
    pub host: String,
    /// Port being verified when the store access failed.
    pub port: u16,
    /// Path of the known-hosts file that could not be accessed.
    pub store_path: PathBuf,
    /// Static label naming the operation that failed.
    // NOTE(review): presumably values like "read"/"write" — confirm at call sites.
    pub operation: &'static str,
    /// Rendered message of the underlying error. Kept as a `String`, which
    /// lets the struct derive `Clone`.
    pub source: String,
}

43#[derive(Debug, Clone)]
45pub enum HostKeyVerificationFailure {
46 Mismatch(HostKeyMismatch),
47 StoreAccess(HostKeyStoreAccessError),
48}
49
50pub type VerificationFailureSlot = Arc<std::sync::Mutex<Option<HostKeyVerificationFailure>>>;
54
55pub struct HostKeyStore {
60 path: PathBuf,
61 state: Mutex<Option<HashMap<String, String>>>,
62}
63
64impl HostKeyStore {
65 pub fn new(path: PathBuf) -> Self {
66 Self {
67 path,
68 state: Mutex::new(None),
69 }
70 }
71
72 pub fn default_path() -> PathBuf {
75 dirs::config_dir()
76 .unwrap_or_else(std::env::temp_dir)
77 .join("r-shell")
78 .join("known_hosts")
79 }
80
81 pub fn path(&self) -> &Path {
82 &self.path
83 }
84
85 pub async fn verify(
88 &self,
89 host: &str,
90 port: u16,
91 key: &russh_keys::key::PublicKey,
92 ) -> Result<Verdict> {
93 let offered = key.public_key_base64();
94 let offered_fp = key.fingerprint();
95 let key_id = Self::make_key(host, port);
96
97 let mut guard = self.state.lock().await;
98 if guard.is_none() {
99 *guard = Some(Self::load_from_disk(&self.path).await?);
100 }
101 let entries = guard.as_ref().expect("state initialised above");
102
103 let verdict = match entries.get(&key_id) {
104 Some(stored) if stored == &offered => Verdict::Known,
105 Some(stored) => Verdict::Mismatch {
106 expected_fingerprint: fingerprint_from_stored(stored),
107 got_fingerprint: offered_fp,
108 },
109 None => Verdict::Unknown,
110 };
111
112 Ok(verdict)
113 }
114
115 pub async fn trust(
118 &self,
119 host: &str,
120 port: u16,
121 key: &russh_keys::key::PublicKey,
122 ) -> Result<()> {
123 let offered = key.public_key_base64();
124 let key_id = Self::make_key(host, port);
125
126 let mut guard = self.state.lock().await;
127 if guard.is_none() {
128 *guard = Some(Self::load_from_disk(&self.path).await?);
129 }
130
131 let mut snapshot = guard.as_ref().cloned().unwrap_or_default();
132 snapshot.insert(key_id, offered);
133
134 self.write_to_disk(&snapshot).await?;
135 *guard = Some(snapshot);
136 Ok(())
137 }
138
139 pub async fn forget(&self, host: &str, port: u16) -> Result<bool> {
145 let key_id = Self::make_key(host, port);
146
147 let mut guard = self.state.lock().await;
148 if guard.is_none() {
149 *guard = Some(Self::load_from_disk(&self.path).await?);
150 }
151
152 let mut snapshot = guard.as_ref().cloned().unwrap_or_default();
153 let removed = snapshot.remove(&key_id).is_some();
154 if removed {
155 self.write_to_disk(&snapshot).await?;
156 *guard = Some(snapshot);
157 }
158 Ok(removed)
159 }
160
161 fn make_key(host: &str, port: u16) -> String {
164 if port == 22 {
165 host.to_string()
166 } else {
167 format!("[{}]:{}", host, port)
168 }
169 }
170
171 async fn load_from_disk(path: &Path) -> Result<HashMap<String, String>> {
172 let mut map = HashMap::new();
173 let content = match tokio::fs::read_to_string(path).await {
174 Ok(s) => s,
175 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(map),
176 Err(e) => {
177 return Err(e)
178 .with_context(|| format!("failed to read known_hosts at {}", path.display()));
179 }
180 };
181
182 for line in content.lines() {
183 let trimmed = line.trim();
184 if trimmed.is_empty() || trimmed.starts_with('#') {
185 continue;
186 }
187 let mut parts = trimmed.splitn(2, char::is_whitespace);
188 if let (Some(host_id), Some(key_blob)) = (parts.next(), parts.next()) {
189 map.insert(host_id.to_string(), key_blob.trim().to_string());
190 }
191 }
192 Ok(map)
193 }
194
195 async fn write_to_disk(&self, entries: &HashMap<String, String>) -> Result<()> {
196 if let Some(parent) = self.path.parent() {
197 tokio::fs::create_dir_all(parent)
198 .await
199 .with_context(|| format!("failed to create {}", parent.display()))?;
200 }
201
202 let mut content =
203 String::from("# r-shell known hosts — auto-managed, one entry per host\n");
204 let mut keys: Vec<&String> = entries.keys().collect();
205 keys.sort();
206 for k in keys {
207 if let Some(v) = entries.get(k) {
208 content.push_str(k);
209 content.push(' ');
210 content.push_str(v);
211 content.push('\n');
212 }
213 }
214 tokio::fs::write(&self.path, content)
215 .await
216 .with_context(|| format!("failed to write {}", self.path.display()))?;
217 Ok(())
218 }
219}
220
221fn fingerprint_from_stored(blob_b64: &str) -> String {
225 match russh_keys::parse_public_key_base64(blob_b64) {
226 Ok(key) => key.fingerprint(),
227 Err(_) => String::from("<unparseable stored key>"),
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use russh_keys::key::KeyPair;
    use tempfile::TempDir;

    /// Builds a store whose file does not exist yet. Keep the returned
    /// `TempDir` alive for as long as the store is used.
    fn temp_store() -> (TempDir, HostKeyStore) {
        let dir = TempDir::new().expect("tmpdir");
        let path = dir.path().join("known_hosts");
        (dir, HostKeyStore::new(path))
    }

    /// A fresh, randomly generated ed25519 public key.
    fn test_public_key() -> russh_keys::key::PublicKey {
        KeyPair::generate_ed25519()
            .expect("generate keypair")
            .clone_public_key()
            .expect("clone public key")
    }

    #[test]
    fn make_key_uses_bracket_form_for_non_default_port() {
        assert_eq!(HostKeyStore::make_key("host", 22), "host");
        assert_eq!(HostKeyStore::make_key("host", 2222), "[host]:2222");
    }

    #[tokio::test]
    async fn unknown_host_yields_unknown_verdict() {
        let (_dir, store) = temp_store();
        let verdict = store
            .verify("host", 22, &test_public_key())
            .await
            .expect("a missing file is an empty store");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
    }

    #[tokio::test]
    async fn trusted_key_is_known_and_survives_reload() {
        let (dir, store) = temp_store();
        let key = test_public_key();
        store.trust("host", 2222, &key).await.expect("trust");

        let verdict = store.verify("host", 2222, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Known), "got {verdict:?}");

        let on_disk =
            std::fs::read_to_string(dir.path().join("known_hosts")).expect("read known_hosts");
        assert!(on_disk.lines().any(|line| line.starts_with("[host]:2222 ")));

        // A fresh store over the same file must reach the same verdicts.
        let reopened = HostKeyStore::new(dir.path().join("known_hosts"));
        let verdict = reopened.verify("host", 2222, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Known), "got {verdict:?}");
        // The default port is a distinct entry.
        let verdict = reopened.verify("host", 22, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
    }

    #[tokio::test]
    async fn different_key_yields_mismatch_with_both_fingerprints() {
        let (_dir, store) = temp_store();
        let trusted = test_public_key();
        let offered = test_public_key();
        store.trust("host", 22, &trusted).await.expect("trust");

        match store.verify("host", 22, &offered).await.expect("verify") {
            Verdict::Mismatch {
                expected_fingerprint,
                got_fingerprint,
            } => {
                assert_eq!(expected_fingerprint, trusted.fingerprint());
                assert_eq!(got_fingerprint, offered.fingerprint());
            }
            other => panic!("expected a mismatch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn forget_removes_only_existing_entries() {
        let (_dir, store) = temp_store();
        let key = test_public_key();
        store.trust("host", 22, &key).await.expect("trust");

        assert!(store.forget("host", 22).await.expect("forget"));
        assert!(!store.forget("host", 22).await.expect("forget again"));
        let verdict = store.verify("host", 22, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
    }

    #[tokio::test]
    async fn verify_propagates_store_read_errors() {
        let dir = TempDir::new().expect("tmpdir");
        let store = HostKeyStore::new(dir.path().to_path_buf());

        let err = store
            .verify("host", 22, &test_public_key())
            .await
            .expect_err("directory path must not be treated as an empty store");

        assert!(err.to_string().contains("failed to read known_hosts"));
    }

    #[tokio::test]
    async fn trust_does_not_cache_keys_when_write_fails() {
        let dir = TempDir::new().expect("tmpdir");
        let file_parent = dir.path().join("not-a-dir");
        std::fs::write(&file_parent, "regular file").expect("write blocker file");
        let store = HostKeyStore::new(file_parent.join("known_hosts"));
        let key = test_public_key();

        store
            .trust("host", 22, &key)
            .await
            .expect_err("write should fail when parent is not a directory");

        let guard = store.state.lock().await;
        assert!(
            guard.as_ref().is_none_or(HashMap::is_empty),
            "failed trust must not mark the key as known in memory",
        );
    }
}