Skip to main content

ssh_commander_core/ssh/
host_keys.rs

1use anyhow::{Context, Result};
2use russh_keys::PublicKeyBase64;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use tokio::sync::Mutex;
7
/// Outcome of checking a server-offered host key against the local store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Verdict {
    /// Host is known and the offered key matches the stored one.
    Known,
    /// Host has never been seen — safe to TOFU-trust.
    Unknown,
    /// Host is known but the key has changed — refuse the connection.
    Mismatch {
        /// Fingerprint of the key previously trusted for this host.
        expected_fingerprint: String,
        /// Fingerprint of the key the server offered this time.
        got_fingerprint: String,
    },
}
21
/// Details surfaced to the user when `Verdict::Mismatch` caused a rejection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HostKeyMismatch {
    /// Host the connection was attempted to.
    pub host: String,
    /// Port the connection was attempted to.
    pub port: u16,
    /// Fingerprint of the key stored as trusted for `(host, port)`.
    pub expected_fingerprint: String,
    /// Fingerprint of the key the server actually offered.
    pub got_fingerprint: String,
    /// Path of the trust store file holding the expected entry.
    pub store_path: PathBuf,
}
31
/// Details surfaced to the user when host-key verification could not complete
/// because the trust store was unavailable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HostKeyStoreAccessError {
    /// Host the connection was attempted to.
    pub host: String,
    /// Port the connection was attempted to.
    pub port: u16,
    /// Path of the trust store file that could not be accessed.
    pub store_path: PathBuf,
    /// Label for the store operation that failed.
    /// NOTE(review): presumably a short verb such as "read"/"write" — confirm
    /// against the code that constructs this value.
    pub operation: &'static str,
    /// Text of the underlying error, held as a plain `String`.
    pub source: String,
}
42
43/// Any verification failure that should be surfaced verbosely to the caller.
44#[derive(Debug, Clone)]
45pub enum HostKeyVerificationFailure {
46    Mismatch(HostKeyMismatch),
47    StoreAccess(HostKeyStoreAccessError),
48}
49
50/// Slot that a `Client` instance writes a host-key verification failure into
51/// during the SSH handshake. The caller of `connect` reads it after the error
52/// to build a descriptive user-facing message.
53pub type VerificationFailureSlot = Arc<std::sync::Mutex<Option<HostKeyVerificationFailure>>>;
54
55/// Persistent store of trusted SSH host keys (analogous to `~/.ssh/known_hosts`).
56///
57/// Internally lazily loaded on first use. Safe to clone `Arc<HostKeyStore>`
58/// across many connections — all access is serialised through an async Mutex.
59pub struct HostKeyStore {
60    path: PathBuf,
61    state: Mutex<Option<HashMap<String, String>>>,
62}
63
64impl HostKeyStore {
65    pub fn new(path: PathBuf) -> Self {
66        Self {
67            path,
68            state: Mutex::new(None),
69        }
70    }
71
72    /// Default location: `$XDG_CONFIG_HOME/r-shell/known_hosts` (or platform
73    /// equivalent via `dirs::config_dir()`).
74    pub fn default_path() -> PathBuf {
75        dirs::config_dir()
76            .unwrap_or_else(std::env::temp_dir)
77            .join("r-shell")
78            .join("known_hosts")
79    }
80
81    pub fn path(&self) -> &Path {
82        &self.path
83    }
84
85    /// Check whether the server-offered key matches the stored fingerprint for
86    /// `(host, port)`. Does not mutate the store.
87    pub async fn verify(
88        &self,
89        host: &str,
90        port: u16,
91        key: &russh_keys::key::PublicKey,
92    ) -> Result<Verdict> {
93        let offered = key.public_key_base64();
94        let offered_fp = key.fingerprint();
95        let key_id = Self::make_key(host, port);
96
97        let mut guard = self.state.lock().await;
98        if guard.is_none() {
99            *guard = Some(Self::load_from_disk(&self.path).await?);
100        }
101        let entries = guard.as_ref().expect("state initialised above");
102
103        let verdict = match entries.get(&key_id) {
104            Some(stored) if stored == &offered => Verdict::Known,
105            Some(stored) => Verdict::Mismatch {
106                expected_fingerprint: fingerprint_from_stored(stored),
107                got_fingerprint: offered_fp,
108            },
109            None => Verdict::Unknown,
110        };
111
112        Ok(verdict)
113    }
114
115    /// Persist the server-offered key as trusted for `(host, port)`.
116    /// Creates the parent directory if missing.
117    pub async fn trust(
118        &self,
119        host: &str,
120        port: u16,
121        key: &russh_keys::key::PublicKey,
122    ) -> Result<()> {
123        let offered = key.public_key_base64();
124        let key_id = Self::make_key(host, port);
125
126        let mut guard = self.state.lock().await;
127        if guard.is_none() {
128            *guard = Some(Self::load_from_disk(&self.path).await?);
129        }
130
131        let mut snapshot = guard.as_ref().cloned().unwrap_or_default();
132        snapshot.insert(key_id, offered);
133
134        self.write_to_disk(&snapshot).await?;
135        *guard = Some(snapshot);
136        Ok(())
137    }
138
139    /// Forget a previously-trusted host. Returns `true` if an entry was
140    /// removed, `false` if there was nothing to remove. Used by the UI's
141    /// "Trust new key" flow on a `HostKeyMismatch`: forget the stale
142    /// entry, retry the connect, the next `verify()` falls through to
143    /// `Verdict::Unknown` and the new key is TOFU-trusted.
144    pub async fn forget(&self, host: &str, port: u16) -> Result<bool> {
145        let key_id = Self::make_key(host, port);
146
147        let mut guard = self.state.lock().await;
148        if guard.is_none() {
149            *guard = Some(Self::load_from_disk(&self.path).await?);
150        }
151
152        let mut snapshot = guard.as_ref().cloned().unwrap_or_default();
153        let removed = snapshot.remove(&key_id).is_some();
154        if removed {
155            self.write_to_disk(&snapshot).await?;
156            *guard = Some(snapshot);
157        }
158        Ok(removed)
159    }
160
161    /// Normalize host:port into a known_hosts-style key.
162    /// Non-default ports use the `[host]:port` form to match OpenSSH conventions.
163    fn make_key(host: &str, port: u16) -> String {
164        if port == 22 {
165            host.to_string()
166        } else {
167            format!("[{}]:{}", host, port)
168        }
169    }
170
171    async fn load_from_disk(path: &Path) -> Result<HashMap<String, String>> {
172        let mut map = HashMap::new();
173        let content = match tokio::fs::read_to_string(path).await {
174            Ok(s) => s,
175            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(map),
176            Err(e) => {
177                return Err(e)
178                    .with_context(|| format!("failed to read known_hosts at {}", path.display()));
179            }
180        };
181
182        for line in content.lines() {
183            let trimmed = line.trim();
184            if trimmed.is_empty() || trimmed.starts_with('#') {
185                continue;
186            }
187            let mut parts = trimmed.splitn(2, char::is_whitespace);
188            if let (Some(host_id), Some(key_blob)) = (parts.next(), parts.next()) {
189                map.insert(host_id.to_string(), key_blob.trim().to_string());
190            }
191        }
192        Ok(map)
193    }
194
195    async fn write_to_disk(&self, entries: &HashMap<String, String>) -> Result<()> {
196        if let Some(parent) = self.path.parent() {
197            tokio::fs::create_dir_all(parent)
198                .await
199                .with_context(|| format!("failed to create {}", parent.display()))?;
200        }
201
202        let mut content =
203            String::from("# r-shell known hosts — auto-managed, one entry per host\n");
204        let mut keys: Vec<&String> = entries.keys().collect();
205        keys.sort();
206        for k in keys {
207            if let Some(v) = entries.get(k) {
208                content.push_str(k);
209                content.push(' ');
210                content.push_str(v);
211                content.push('\n');
212            }
213        }
214        tokio::fs::write(&self.path, content)
215            .await
216            .with_context(|| format!("failed to write {}", self.path.display()))?;
217        Ok(())
218    }
219}
220
221/// Compute an SHA-256 fingerprint from a stored base64 public-key blob,
222/// matching the format returned by `key::PublicKey::fingerprint()` so both
223/// sides of a mismatch display in the same form.
224fn fingerprint_from_stored(blob_b64: &str) -> String {
225    match russh_keys::parse_public_key_base64(blob_b64) {
226        Ok(key) => key.fingerprint(),
227        Err(_) => String::from("<unparseable stored key>"),
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use russh_keys::key::KeyPair;
    use tempfile::TempDir;

    /// Store backed by `known_hosts` inside a fresh temp directory. Keep the
    /// returned `TempDir` alive for as long as the store is used.
    fn temp_store() -> (TempDir, HostKeyStore) {
        let dir = TempDir::new().expect("tmpdir");
        let path = dir.path().join("known_hosts");
        (dir, HostKeyStore::new(path))
    }

    /// A freshly generated (hence unique) ed25519 public key.
    fn test_public_key() -> russh_keys::key::PublicKey {
        KeyPair::generate_ed25519()
            .expect("generate keypair")
            .clone_public_key()
            .expect("clone public key")
    }

    #[test]
    fn make_key_uses_bracket_form_for_non_default_port() {
        assert_eq!(HostKeyStore::make_key("host", 22), "host");
        assert_eq!(HostKeyStore::make_key("host", 2222), "[host]:2222");
    }

    #[tokio::test]
    async fn missing_store_file_loads_as_empty() {
        let (_dir, store) = temp_store();
        let entries = HostKeyStore::load_from_disk(store.path())
            .await
            .expect("a missing file must not be an error");
        assert!(entries.is_empty());
    }

    #[tokio::test]
    async fn unknown_host_yields_unknown_verdict() {
        let (_dir, store) = temp_store();
        let verdict = store
            .verify("host", 22, &test_public_key())
            .await
            .expect("verify");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
    }

    #[tokio::test]
    async fn trusted_key_is_known_and_survives_reload() {
        let (dir, store) = temp_store();
        let key = test_public_key();
        store.trust("host", 22, &key).await.expect("trust");

        let verdict = store.verify("host", 22, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Known), "got {verdict:?}");

        // A fresh store over the same file must see the persisted entry.
        let reloaded = HostKeyStore::new(store.path().to_path_buf());
        let verdict = reloaded.verify("host", 22, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Known), "got {verdict:?}");

        // Writing must not leave any staging files behind.
        let names: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .expect("read_dir")
            .map(|entry| entry.expect("dir entry").file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("known_hosts")]);
    }

    #[tokio::test]
    async fn changed_key_yields_mismatch_with_both_fingerprints() {
        let (_dir, store) = temp_store();
        let trusted = test_public_key();
        let offered = test_public_key();
        store.trust("host", 22, &trusted).await.expect("trust");

        match store.verify("host", 22, &offered).await.expect("verify") {
            Verdict::Mismatch {
                expected_fingerprint,
                got_fingerprint,
            } => {
                assert_eq!(expected_fingerprint, trusted.fingerprint());
                assert_eq!(got_fingerprint, offered.fingerprint());
            }
            other => panic!("expected a mismatch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn ports_are_tracked_independently() {
        let (_dir, store) = temp_store();
        let key = test_public_key();
        store.trust("host", 22, &key).await.expect("trust");

        let verdict = store.verify("host", 2222, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
    }

    #[tokio::test]
    async fn forget_removes_entry_exactly_once() {
        let (_dir, store) = temp_store();
        let key = test_public_key();
        store.trust("host", 22, &key).await.expect("trust");

        assert!(store.forget("host", 22).await.expect("forget"));
        let verdict = store.verify("host", 22, &key).await.expect("verify");
        assert!(matches!(verdict, Verdict::Unknown), "got {verdict:?}");
        assert!(!store.forget("host", 22).await.expect("second forget"));
    }

    #[tokio::test]
    async fn loader_skips_comments_blank_and_malformed_lines() {
        let (_dir, store) = temp_store();
        std::fs::write(
            store.path(),
            "# comment\n\n  host   AAAA  \nlonely\n[h]:2222 BBBB\n",
        )
        .expect("seed store file");

        let entries = HostKeyStore::load_from_disk(store.path())
            .await
            .expect("load");
        assert_eq!(entries.len(), 2);
        assert_eq!(entries.get("host").map(String::as_str), Some("AAAA"));
        assert_eq!(entries.get("[h]:2222").map(String::as_str), Some("BBBB"));
    }

    #[tokio::test]
    async fn verify_propagates_store_read_errors() {
        let dir = TempDir::new().expect("tmpdir");
        let store = HostKeyStore::new(dir.path().to_path_buf());

        let err = store
            .verify("host", 22, &test_public_key())
            .await
            .expect_err("directory path must not be treated as an empty store");

        assert!(err.to_string().contains("failed to read known_hosts"));
    }

    #[tokio::test]
    async fn trust_does_not_cache_keys_when_write_fails() {
        let dir = TempDir::new().expect("tmpdir");
        let file_parent = dir.path().join("not-a-dir");
        std::fs::write(&file_parent, "regular file").expect("write blocker file");
        let store = HostKeyStore::new(file_parent.join("known_hosts"));
        let key = test_public_key();

        store
            .trust("host", 22, &key)
            .await
            .expect_err("write should fail when parent is not a directory");

        let guard = store.state.lock().await;
        assert!(
            guard.as_ref().is_none_or(HashMap::is_empty),
            "failed trust must not mark the key as known in memory",
        );
    }
}