1use ssh_key::{PublicKey as SshPublicKey, PrivateKey as SshPrivateKey, Algorithm};
4use std::{
5 fs,
6 path::{Path, PathBuf},
7};
8use thiserror::Error;
9
10#[derive(Error, Debug)]
12pub enum SshKeyError {
13 #[error("No SSH directory found")]
14 NoSshDirectory,
15 #[error("No suitable public keys found")]
16 NoPublicKeysFound,
17 #[error("Invalid public key format: {0}")]
18 InvalidKeyFormat(String),
19 #[error("Unsupported key algorithm: {0}")]
20 UnsupportedAlgorithm(String),
21 #[error("IO error: {0}")]
22 IoError(#[from] std::io::Error),
23 #[error("SSH key parsing error: {0}")]
24 SshKeyError(#[from] ssh_key::Error),
25 #[error("Key generation failed: {0}")]
26 KeyGenerationFailed(String),
27}
28
/// SSH key algorithms supported by this module.
///
/// This is a plain fieldless enum, so it is `Copy` and can be used as a
/// hash-map / hash-set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyAlgorithm {
    /// RSA keys.
    Rsa,
    /// ECDSA keys on the NIST P-256 curve (`nistp256`).
    EcdsaP256,
    /// Ed25519 keys.
    Ed25519,
}
36
37impl std::fmt::Display for KeyAlgorithm {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 KeyAlgorithm::Rsa => write!(f, "RSA"),
41 KeyAlgorithm::EcdsaP256 => write!(f, "ECDSA-P256"),
42 KeyAlgorithm::Ed25519 => write!(f, "Ed25519"),
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
49pub struct HybridPublicKey {
50 pub ssh_key: SshPublicKey,
52 pub algorithm: KeyAlgorithm,
54 pub file_path: PathBuf,
56 pub comment: String,
58}
59
60impl HybridPublicKey {
61 pub fn new(ssh_key: SshPublicKey, file_path: PathBuf) -> Result<Self, SshKeyError> {
63 let algorithm = match ssh_key.algorithm() {
64 Algorithm::Rsa { .. } => KeyAlgorithm::Rsa,
65 Algorithm::Ecdsa { curve } => {
66 match curve.as_str() {
67 "nistp256" => KeyAlgorithm::EcdsaP256,
68 _ => return Err(SshKeyError::UnsupportedAlgorithm(curve.to_string())),
69 }
70 }
71 Algorithm::Ed25519 => KeyAlgorithm::Ed25519,
72 alg => return Err(SshKeyError::UnsupportedAlgorithm(alg.to_string())),
73 };
74
75 let comment = ssh_key.comment().to_string();
76
77 Ok(Self {
78 ssh_key,
79 algorithm,
80 file_path,
81 comment,
82 })
83 }
84
85 pub fn display_name(&self) -> String {
87 let filename = self.file_path.file_name()
88 .and_then(|n| n.to_str())
89 .unwrap_or("unknown");
90
91 if self.comment.is_empty() {
92 format!("{} ({})", filename, self.algorithm)
93 } else {
94 format!("{} ({}) - {}", filename, self.algorithm, self.comment)
95 }
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct HybridPrivateKey {
102 pub ssh_key: SshPrivateKey,
104 pub algorithm: KeyAlgorithm,
106 pub file_path: PathBuf,
108 pub comment: String,
110}
111
112impl HybridPrivateKey {
113 pub fn new(ssh_key: SshPrivateKey, file_path: PathBuf) -> Result<Self, SshKeyError> {
115 let algorithm = match ssh_key.algorithm() {
116 Algorithm::Rsa { .. } => KeyAlgorithm::Rsa,
117 Algorithm::Ecdsa { curve } => {
118 match curve.as_str() {
119 "nistp256" => KeyAlgorithm::EcdsaP256,
120 _ => return Err(SshKeyError::UnsupportedAlgorithm(curve.to_string())),
121 }
122 }
123 Algorithm::Ed25519 => KeyAlgorithm::Ed25519,
124 alg => return Err(SshKeyError::UnsupportedAlgorithm(alg.to_string())),
125 };
126
127 let comment = ssh_key.comment().to_string();
128
129 Ok(Self {
130 ssh_key,
131 algorithm,
132 file_path,
133 comment,
134 })
135 }
136
137 pub fn display_name(&self) -> String {
139 let filename = self.file_path.file_name()
140 .and_then(|n| n.to_str())
141 .unwrap_or("unknown");
142
143 if self.comment.is_empty() {
144 format!("{} ({})", filename, self.algorithm)
145 } else {
146 format!("{} ({}) - {}", filename, self.algorithm, self.comment)
147 }
148 }
149
150 pub fn public_key(&self) -> HybridPublicKey {
152 let public_ssh_key = self.ssh_key.public_key().clone();
153 HybridPublicKey::new(public_ssh_key, self.file_path.clone())
155 .expect("Failed to create public key from validated private key")
156 }
157}
158
159pub struct SshKeyDiscovery {
161 ssh_dir: PathBuf,
162}
163
164impl Default for SshKeyDiscovery {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl SshKeyDiscovery {
171 pub fn new() -> Self {
173 let ssh_dir = dirs::home_dir()
174 .map(|home| home.join(".ssh"))
175 .unwrap_or_else(|| PathBuf::from(".ssh"));
176
177 Self { ssh_dir }
178 }
179
180 pub fn with_ssh_dir<P: AsRef<Path>>(ssh_dir: P) -> Self {
182 Self {
183 ssh_dir: ssh_dir.as_ref().to_path_buf(),
184 }
185 }
186
187 pub fn discover_keys(&self) -> Result<Vec<HybridPublicKey>, SshKeyError> {
189 if !self.ssh_dir.exists() {
190 return Err(SshKeyError::NoSshDirectory);
191 }
192
193 let mut keys = Vec::new();
194 let entries = fs::read_dir(&self.ssh_dir)?;
195
196 for entry in entries {
197 let entry = entry?;
198 let path = entry.path();
199
200 if let Some(extension) = path.extension() {
202 if extension == "pub" {
203 match self.load_public_key(&path) {
204 Ok(key) => {
205 println!("š Found public key: {}", key.display_name());
206 keys.push(key);
207 },
208 Err(e) => {
209 eprintln!("Warning: Failed to load key {}: {}", path.display(), e);
211 }
212 }
213 }
214 }
215 }
216
217 if keys.is_empty() {
218 return Err(SshKeyError::NoPublicKeysFound);
219 }
220
221 keys.sort_by(|a, b| {
223 match (&a.algorithm, &b.algorithm) {
224 (KeyAlgorithm::Rsa, KeyAlgorithm::EcdsaP256) => std::cmp::Ordering::Less,
225 (KeyAlgorithm::Rsa, KeyAlgorithm::Ed25519) => std::cmp::Ordering::Less,
226 (KeyAlgorithm::EcdsaP256, KeyAlgorithm::Rsa) => std::cmp::Ordering::Greater,
227 (KeyAlgorithm::EcdsaP256, KeyAlgorithm::Ed25519) => std::cmp::Ordering::Less,
228 (KeyAlgorithm::Ed25519, KeyAlgorithm::Rsa) => std::cmp::Ordering::Greater,
229 (KeyAlgorithm::Ed25519, KeyAlgorithm::EcdsaP256) => std::cmp::Ordering::Greater,
230 _ => a.file_path.cmp(&b.file_path),
231 }
232 });
233
234 Ok(keys)
235 }
236
237 pub fn load_public_key_from_path<P: AsRef<Path>>(&self, path: P) -> Result<HybridPublicKey, SshKeyError> {
239 self.load_public_key(path.as_ref())
240 }
241
242 fn load_public_key(&self, path: &Path) -> Result<HybridPublicKey, SshKeyError> {
244 let content = fs::read_to_string(path)?;
245 let ssh_key = SshPublicKey::from_openssh(&content)
246 .map_err(|e| SshKeyError::InvalidKeyFormat(format!("{}: {}", path.display(), e)))?;
247
248 HybridPublicKey::new(ssh_key, path.to_path_buf())
249 }
250
251 pub fn get_default_key(&self) -> Result<HybridPublicKey, SshKeyError> {
253 let keys = self.discover_keys()?;
254
255 if let Some(rsa_key) = keys.iter().find(|k| k.algorithm == KeyAlgorithm::Rsa) {
257 Ok(rsa_key.clone())
258 } else if let Some(ecdsa_key) = keys.iter().find(|k| k.algorithm == KeyAlgorithm::EcdsaP256) {
259 Ok(ecdsa_key.clone())
260 } else if let Some(first_key) = keys.into_iter().next() {
261 Ok(first_key)
262 } else {
263 Err(SshKeyError::NoPublicKeysFound)
264 }
265 }
266
267 pub fn find_keys_by_algorithm(&self, algorithm: KeyAlgorithm) -> Result<Vec<HybridPublicKey>, SshKeyError> {
269 let keys = self.discover_keys()?;
270 let filtered: Vec<_> = keys.into_iter()
271 .filter(|k| k.algorithm == algorithm)
272 .collect();
273
274 if filtered.is_empty() {
275 Err(SshKeyError::NoPublicKeysFound)
276 } else {
277 Ok(filtered)
278 }
279 }
280
281 pub fn discover_private_keys(&self) -> Result<Vec<HybridPrivateKey>, SshKeyError> {
283 if !self.ssh_dir.exists() {
284 return Err(SshKeyError::NoSshDirectory);
285 }
286
287 let mut keys = Vec::new();
288 let entries = fs::read_dir(&self.ssh_dir)?;
289
290 for entry in entries {
291 let entry = entry?;
292 let path = entry.path();
293
294 if path.is_file() && !path.extension().map_or(false, |ext| ext == "pub") {
296 let filename = path.file_name()
298 .and_then(|n| n.to_str())
299 .unwrap_or("");
300
301 if filename.starts_with("known_hosts") ||
302 filename.starts_with("config") ||
303 filename.starts_with("authorized_keys") {
304 continue;
305 }
306
307 match self.load_private_key(&path) {
308 Ok(key) => {
309 println!("š Found private key: {}", key.display_name());
310 keys.push(key);
311 },
312 Err(e) => {
313 eprintln!("Warning: Failed to load private key {}: {}", path.display(), e);
315 }
316 }
317 }
318 }
319
320 if keys.is_empty() {
321 return Err(SshKeyError::NoPublicKeysFound);
322 }
323
324 keys.sort_by(|a, b| {
326 match (&a.algorithm, &b.algorithm) {
327 (KeyAlgorithm::Rsa, KeyAlgorithm::EcdsaP256) => std::cmp::Ordering::Less,
328 (KeyAlgorithm::Rsa, KeyAlgorithm::Ed25519) => std::cmp::Ordering::Less,
329 (KeyAlgorithm::EcdsaP256, KeyAlgorithm::Rsa) => std::cmp::Ordering::Greater,
330 (KeyAlgorithm::EcdsaP256, KeyAlgorithm::Ed25519) => std::cmp::Ordering::Less,
331 (KeyAlgorithm::Ed25519, KeyAlgorithm::Rsa) => std::cmp::Ordering::Greater,
332 (KeyAlgorithm::Ed25519, KeyAlgorithm::EcdsaP256) => std::cmp::Ordering::Greater,
333 _ => a.file_path.cmp(&b.file_path),
334 }
335 });
336
337 Ok(keys)
338 }
339
340 pub fn load_private_key_from_path<P: AsRef<Path>>(&self, path: P) -> Result<HybridPrivateKey, SshKeyError> {
342 self.load_private_key(path.as_ref())
343 }
344
345 fn load_private_key(&self, path: &Path) -> Result<HybridPrivateKey, SshKeyError> {
347 let content = fs::read_to_string(path)?;
348
349 let ssh_key = SshPrivateKey::from_openssh(&content)
351 .map_err(|e| SshKeyError::InvalidKeyFormat(format!("{}: {}", path.display(), e)))?;
352
353 HybridPrivateKey::new(ssh_key, path.to_path_buf())
354 }
355
356 pub fn get_default_private_key(&self) -> Result<HybridPrivateKey, SshKeyError> {
358 let keys = self.discover_private_keys()?;
359
360 if let Some(rsa_key) = keys.iter().find(|k| k.algorithm == KeyAlgorithm::Rsa) {
362 Ok(rsa_key.clone())
363 } else if let Some(ecdsa_key) = keys.iter().find(|k| k.algorithm == KeyAlgorithm::EcdsaP256) {
364 Ok(ecdsa_key.clone())
365 } else if let Some(first_key) = keys.into_iter().next() {
366 Ok(first_key)
367 } else {
368 Err(SshKeyError::NoPublicKeysFound)
369 }
370 }
371
372 pub fn find_private_keys_by_algorithm(&self, algorithm: KeyAlgorithm) -> Result<Vec<HybridPrivateKey>, SshKeyError> {
374 let keys = self.discover_private_keys()?;
375 let filtered: Vec<_> = keys.into_iter()
376 .filter(|k| k.algorithm == algorithm)
377 .collect();
378
379 if filtered.is_empty() {
380 Err(SshKeyError::NoPublicKeysFound)
381 } else {
382 Ok(filtered)
383 }
384 }
385
386 pub fn check_ssh_directory(&self) -> Result<(), SshKeyError> {
388 if !self.ssh_dir.exists() {
389 return Err(SshKeyError::NoSshDirectory);
390 }
391
392 fs::read_dir(&self.ssh_dir)?;
394 Ok(())
395 }
396
397 pub fn select_public_key_interactive(&self) -> Result<HybridPublicKey, SshKeyError> {
399 let keys = self.discover_keys()?;
400
401 if keys.is_empty() {
402 return Err(SshKeyError::NoPublicKeysFound);
403 }
404
405 if keys.len() == 1 {
406 println!("š Using public key: {}", keys[0].display_name());
407 return Ok(keys[0].clone());
408 }
409
410 println!("\nš Multiple public keys found in ~/.ssh:");
412 for (index, key) in keys.iter().enumerate() {
413 println!(" [{}] {}", index + 1, key.display_name());
414 }
415
416 loop {
417 print!("\nSelect a key (1-{}): ", keys.len());
418 use std::io::{self, Write};
419 io::stdout().flush().unwrap();
420
421 let mut input = String::new();
422 io::stdin().read_line(&mut input).map_err(|e| {
423 SshKeyError::IoError(e)
424 })?;
425
426 if let Ok(selection) = input.trim().parse::<usize>() {
427 if selection >= 1 && selection <= keys.len() {
428 let selected_key = &keys[selection - 1];
429 println!("ā
Selected: {}", selected_key.display_name());
430 return Ok(selected_key.clone());
431 }
432 }
433
434 println!("ā Invalid selection. Please enter a number between 1 and {}.", keys.len());
435 }
436 }
437
438 pub fn select_private_key_interactive(&self) -> Result<HybridPrivateKey, SshKeyError> {
440 let keys = self.discover_private_keys()?;
441
442 if keys.is_empty() {
443 return Err(SshKeyError::NoPublicKeysFound);
444 }
445
446 if keys.len() == 1 {
447 println!("š Using private key: {}", keys[0].display_name());
448 return Ok(keys[0].clone());
449 }
450
451 println!("\nš Multiple private keys found in ~/.ssh:");
453 for (index, key) in keys.iter().enumerate() {
454 println!(" [{}] {}", index + 1, key.display_name());
455 }
456
457 loop {
458 print!("\nSelect a key (1-{}): ", keys.len());
459 use std::io::{self, Write};
460 io::stdout().flush().unwrap();
461
462 let mut input = String::new();
463 io::stdin().read_line(&mut input).map_err(|e| {
464 SshKeyError::IoError(e)
465 })?;
466
467 if let Ok(selection) = input.trim().parse::<usize>() {
468 if selection >= 1 && selection <= keys.len() {
469 let selected_key = &keys[selection - 1];
470 println!("ā
Selected: {}", selected_key.display_name());
471 return Ok(selected_key.clone());
472 }
473 }
474
475 println!("ā Invalid selection. Please enter a number between 1 and {}.", keys.len());
476 }
477 }
478
479 pub fn generate_key_pair(
481 &self,
482 algorithm: KeyAlgorithm,
483 key_size: Option<usize>,
484 comment: Option<String>,
485 output_path: Option<PathBuf>,
486 ) -> Result<(PathBuf, PathBuf), SshKeyError> {
487 use std::process::Command;
488
489 let key_type = match algorithm {
493 KeyAlgorithm::Rsa => "rsa",
494 KeyAlgorithm::EcdsaP256 => "ecdsa",
495 KeyAlgorithm::Ed25519 => "ed25519",
496 };
497
498 let (private_path, public_path) = if let Some(base_path) = output_path {
500 let private_path = base_path.clone();
501 let public_path = base_path.with_extension("pub");
502 (private_path, public_path)
503 } else {
504 let key_name = match algorithm {
506 KeyAlgorithm::Rsa => "id_rsa_sf_cli",
507 KeyAlgorithm::EcdsaP256 => "id_ecdsa_sf_cli",
508 KeyAlgorithm::Ed25519 => "id_ed25519_sf_cli",
509 };
510 let private_path = self.ssh_dir.join(key_name);
511 let public_path = self.ssh_dir.join(format!("{}.pub", key_name));
512 (private_path, public_path)
513 };
514
515 if let Some(parent) = private_path.parent() {
517 fs::create_dir_all(parent)?;
518 }
519
520 let mut cmd = Command::new("ssh-keygen");
522 cmd.arg("-t").arg(key_type)
523 .arg("-f").arg(&private_path)
524 .arg("-N").arg("") .arg("-q"); if algorithm == KeyAlgorithm::Rsa {
529 let bits = key_size.unwrap_or(3072);
530 if bits < 2048 {
531 return Err(SshKeyError::KeyGenerationFailed(
532 "RSA key size must be at least 2048 bits".to_string()
533 ));
534 }
535 cmd.arg("-b").arg(bits.to_string());
536 }
537
538 if algorithm == KeyAlgorithm::EcdsaP256 {
540 cmd.arg("-b").arg("256");
541 }
542
543 if let Some(comment_str) = comment {
545 cmd.arg("-C").arg(comment_str);
546 } else {
547 cmd.arg("-C").arg("sf-cli-generated");
548 }
549
550 println!("š Generating {} key pair...", algorithm);
551
552 let output = cmd.output()
554 .map_err(|e| SshKeyError::KeyGenerationFailed(format!("Failed to execute ssh-keygen: {}", e)))?;
555
556 if !output.status.success() {
557 let error_msg = String::from_utf8_lossy(&output.stderr);
558 return Err(SshKeyError::KeyGenerationFailed(
559 format!("ssh-keygen failed: {}", error_msg)
560 ));
561 }
562
563 if !private_path.exists() {
565 return Err(SshKeyError::KeyGenerationFailed(
566 format!("Private key file was not created: {}", private_path.display())
567 ));
568 }
569 if !public_path.exists() {
570 return Err(SshKeyError::KeyGenerationFailed(
571 format!("Public key file was not created: {}", public_path.display())
572 ));
573 }
574
575 println!("ā
{} key pair generated successfully:", algorithm);
576 println!(" Private key: {}", private_path.display());
577 println!(" Public key: {}", public_path.display());
578
579 Ok((private_path, public_path))
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates an empty `.ssh` directory inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the whole test.
    fn empty_ssh_dir() -> (TempDir, std::path::PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let ssh_dir = temp_dir.path().join(".ssh");
        fs::create_dir(&ssh_dir).unwrap();
        (temp_dir, ssh_dir)
    }

    #[test]
    fn test_ssh_key_discovery_no_directory() {
        let temp_dir = TempDir::new().unwrap();
        let nonexistent_ssh_dir = temp_dir.path().join("nonexistent");

        let discovery = SshKeyDiscovery::with_ssh_dir(nonexistent_ssh_dir);
        let result = discovery.discover_keys();

        assert!(matches!(result, Err(SshKeyError::NoSshDirectory)));
    }

    #[test]
    fn test_ssh_key_discovery_empty_directory() {
        let (_temp_dir, ssh_dir) = empty_ssh_dir();

        let discovery = SshKeyDiscovery::with_ssh_dir(ssh_dir);
        let result = discovery.discover_keys();

        assert!(matches!(result, Err(SshKeyError::NoPublicKeysFound)));
    }

    #[test]
    fn test_key_algorithm_display() {
        assert_eq!(KeyAlgorithm::Rsa.to_string(), "RSA");
        assert_eq!(KeyAlgorithm::EcdsaP256.to_string(), "ECDSA-P256");
        assert_eq!(KeyAlgorithm::Ed25519.to_string(), "Ed25519");
    }

    #[test]
    fn test_check_ssh_directory() {
        let (temp_dir, ssh_dir) = empty_ssh_dir();

        let discovery = SshKeyDiscovery::with_ssh_dir(&ssh_dir);
        assert!(discovery.check_ssh_directory().is_ok());

        let nonexistent = temp_dir.path().join("nonexistent");
        let discovery2 = SshKeyDiscovery::with_ssh_dir(nonexistent);
        assert!(matches!(discovery2.check_ssh_directory(), Err(SshKeyError::NoSshDirectory)));
    }

    #[test]
    fn test_private_key_discovery_no_directory() {
        let temp_dir = TempDir::new().unwrap();
        let discovery = SshKeyDiscovery::with_ssh_dir(temp_dir.path().join("nonexistent"));

        assert!(matches!(
            discovery.discover_private_keys(),
            Err(SshKeyError::NoSshDirectory)
        ));
    }

    #[test]
    fn test_private_key_discovery_empty_directory() {
        let (_temp_dir, ssh_dir) = empty_ssh_dir();
        let discovery = SshKeyDiscovery::with_ssh_dir(&ssh_dir);

        // Private-key discovery currently reports "nothing found" with the
        // public-key variant.
        assert!(matches!(
            discovery.discover_private_keys(),
            Err(SshKeyError::NoPublicKeysFound)
        ));
    }

    #[test]
    fn test_unparseable_and_non_key_files_are_skipped() {
        let (_temp_dir, ssh_dir) = empty_ssh_dir();
        fs::write(ssh_dir.join("known_hosts"), "example.com ssh-ed25519 AAAA\n").unwrap();
        fs::write(ssh_dir.join("config"), "Host *\n").unwrap();
        fs::write(ssh_dir.join("notes.txt"), "not a key\n").unwrap();
        fs::write(ssh_dir.join("broken.pub"), "not a public key\n").unwrap();

        let discovery = SshKeyDiscovery::with_ssh_dir(&ssh_dir);

        assert!(matches!(discovery.discover_keys(), Err(SshKeyError::NoPublicKeysFound)));
        assert!(matches!(
            discovery.discover_private_keys(),
            Err(SshKeyError::NoPublicKeysFound)
        ));
    }

    #[test]
    fn test_default_and_filtered_lookups_on_empty_directory() {
        let (_temp_dir, ssh_dir) = empty_ssh_dir();
        let discovery = SshKeyDiscovery::with_ssh_dir(&ssh_dir);

        assert!(matches!(discovery.get_default_key(), Err(SshKeyError::NoPublicKeysFound)));
        assert!(matches!(
            discovery.find_keys_by_algorithm(KeyAlgorithm::Ed25519),
            Err(SshKeyError::NoPublicKeysFound)
        ));
        assert!(matches!(
            discovery.get_default_private_key(),
            Err(SshKeyError::NoPublicKeysFound)
        ));
        assert!(matches!(
            discovery.find_private_keys_by_algorithm(KeyAlgorithm::Rsa),
            Err(SshKeyError::NoPublicKeysFound)
        ));
    }
}