ssh_agent_switcher/
lib.rs1use log::{debug, info, warn};
27use std::env;
28use std::fs;
29use std::io;
30use std::path::{Path, PathBuf};
31use std::thread;
32use std::time::Duration;
33use tokio::net::{UnixListener, UnixStream};
34use tokio::select;
35use tokio::signal::unix::{SignalKind, signal};
36
37mod find;
38
/// Crate-wide result type: failures are carried as human-readable error messages.
type Result<T> = core::result::Result<T, String>;
41
42struct UmaskGuard {
44 old_umask: libc::mode_t,
45}
46
47impl Drop for UmaskGuard {
48 fn drop(&mut self) {
49 let _ = unsafe { libc::umask(self.old_umask) };
50 }
51}
52
53fn set_umask(umask: libc::mode_t) -> UmaskGuard {
55 UmaskGuard { old_umask: unsafe { libc::umask(umask) } }
56}
57
58fn create_listener(socket_path: &Path) -> Result<UnixListener> {
62 let _guard = set_umask(0o177);
65
66 UnixListener::bind(socket_path)
67 .map_err(|e| format!("Cannot listen on {}: {}", socket_path.display(), e))
68}
69
70async fn handle_connection(
72 mut client: UnixStream,
73 agents_dirs: &[PathBuf],
74 home: Option<&Path>,
75 uid: libc::uid_t,
76) -> Result<()> {
77 let mut agent = match find::find_socket(agents_dirs, home, uid).await {
78 Some(socket) => socket,
79 None => {
80 return Err("No agent found; cannot proxy request".to_owned());
81 }
82 };
83 let result = tokio::io::copy_bidirectional(&mut client, &mut agent)
84 .await
85 .map(|_| ())
86 .map_err(|e| format!("{}", e));
87 debug!("Closing client connection");
88 result
89}
90
91pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
97 let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
98 let uid = unsafe { libc::getuid() };
99
100 let mut sighup = signal(SignalKind::hangup())
101 .map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
102 let mut sigint = signal(SignalKind::interrupt())
103 .map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
104 let mut sigquit = signal(SignalKind::quit())
105 .map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
106 let mut sigterm = signal(SignalKind::terminate())
107 .map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
108
109 let listener = create_listener(&socket_path)?;
110
111 debug!("Entering main loop");
112 let mut stop = None;
113 while stop.is_none() {
114 select! {
115 result = listener.accept() => match result {
116 Ok((socket, _addr)) => {
117 debug!("Connection accepted");
118 if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
120 warn!("Dropping connection due to error: {}", e);
121 }
122 }
123 Err(e) => warn!("Failed to accept connection: {}", e),
124 },
125
126 _ = sighup.recv() => (),
127 _ = sigint.recv() => stop = Some("SIGINT"),
128 _ = sigquit.recv() => stop = Some("SIGQUIT"),
129 _ = sigterm.recv() => stop = Some("SIGTERM"),
130 }
131 }
132 debug!("Main loop exited");
133
134 let stop = stop.expect("Loop can only exit by setting stop");
135 info!("Shutting down due to {} and removing {}", stop, socket_path.display());
136
137 let _ = fs::remove_file(&socket_path);
138 let _ = fs::remove_file(&pid_file);
141
142 Ok(())
143}
144
145pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
148 path: P,
149 mut pending_wait: Duration,
150 op: fn(P) -> io::Result<T>,
151) -> Result<T> {
152 while pending_wait > Duration::ZERO {
153 match op(path) {
154 Ok(result) => {
155 return Ok(result);
156 }
157 Err(e) if e.kind() == io::ErrorKind::NotFound => {
158 thread::sleep(Duration::from_millis(1));
159 pending_wait -= Duration::from_millis(1);
160 }
161 Err(e) => {
162 return Err(e.to_string());
163 }
164 }
165 }
166 Err("File was not created on time".to_owned())
167}