Skip to main content

ssh_agent_switcher/
lib.rs

1// Copyright 2025 Julio Merino.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted
5// provided that the following conditions are met:
6//
7// * Redistributions of source code must retain the above copyright notice, this list of conditions
8//   and the following disclaimer.
9// * Redistributions in binary form must reproduce the above copyright notice, this list of
10//   conditions and the following disclaimer in the documentation and/or other materials provided with
11//   the distribution.
12// * Neither the name of ssh-agent-switcher nor the names of its contributors may be used to endorse
13//   or promote products derived from this software without specific prior written permission.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17// FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
18// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
21// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
22// WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23
24//! Serves a Unix domain socket that proxies connections to any valid SSH agent provided by sshd.
25
26use log::{debug, info, warn};
27use std::env;
28use std::fs;
29use std::io;
30use std::path::{Path, PathBuf};
31use std::thread;
32use std::time::Duration;
33use tokio::net::{UnixListener, UnixStream};
34use tokio::select;
35use tokio::signal::unix::{SignalKind, signal};
36
37mod find;
38
39/// Result type for this crate.
40type Result<T> = std::result::Result<T, String>;
41
42/// A scope guard to restore the previous umask.
43struct UmaskGuard {
44    old_umask: libc::mode_t,
45}
46
47impl Drop for UmaskGuard {
48    fn drop(&mut self) {
49        let _ = unsafe { libc::umask(self.old_umask) };
50    }
51}
52
53/// Sets the umask and returns a guard to restore it on drop.
54fn set_umask(umask: libc::mode_t) -> UmaskGuard {
55    UmaskGuard { old_umask: unsafe { libc::umask(umask) } }
56}
57
58/// Creates the agent socket to listen on.
59///
60/// This makes sure that the socket is only accessible by the current user.
61fn create_listener(socket_path: &Path) -> Result<UnixListener> {
62    // Ensure the socket is not group nor world readable so that we don't expose the real socket
63    // indirectly to other users.
64    let _guard = set_umask(0o177);
65
66    UnixListener::bind(socket_path)
67        .map_err(|e| format!("Cannot listen on {}: {}", socket_path.display(), e))
68}
69
70/// Handles one incoming connection on `client`.
71async fn handle_connection(
72    mut client: UnixStream,
73    agents_dirs: &[PathBuf],
74    home: Option<&Path>,
75    uid: libc::uid_t,
76) -> Result<()> {
77    let mut agent = match find::find_socket(agents_dirs, home, uid).await {
78        Some(socket) => socket,
79        None => {
80            return Err("No agent found; cannot proxy request".to_owned());
81        }
82    };
83    let result = tokio::io::copy_bidirectional(&mut client, &mut agent)
84        .await
85        .map(|_| ())
86        .map_err(|e| format!("{}", e));
87    debug!("Closing client connection");
88    result
89}
90
91/// Runs the core logic of the app.
92///
93/// This serves the SSH agent socket on `socket_path` and looks for sshd sockets in `agents_dirs`.
94///
95/// The `pid_file` needs to be passed in for cleanup purposes.
96pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
97    let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
98    let uid = unsafe { libc::getuid() };
99
100    let mut sighup = signal(SignalKind::hangup())
101        .map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
102    let mut sigint = signal(SignalKind::interrupt())
103        .map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
104    let mut sigquit = signal(SignalKind::quit())
105        .map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
106    let mut sigterm = signal(SignalKind::terminate())
107        .map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
108
109    let listener = create_listener(&socket_path)?;
110
111    debug!("Entering main loop");
112    let mut stop = None;
113    while stop.is_none() {
114        select! {
115            result = listener.accept() => match result {
116                Ok((socket, _addr)) => {
117                    debug!("Connection accepted");
118                    // TODO(jmmv): Connections are handled sequentially.  This is... fine.
119                    if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
120                        warn!("Dropping connection due to error: {}", e);
121                    }
122                }
123                Err(e) => warn!("Failed to accept connection: {}", e),
124            },
125
126            _ = sighup.recv() => (),
127            _ = sigint.recv() => stop = Some("SIGINT"),
128            _ = sigquit.recv() => stop = Some("SIGQUIT"),
129            _ = sigterm.recv() => stop = Some("SIGTERM"),
130        }
131    }
132    debug!("Main loop exited");
133
134    let stop = stop.expect("Loop can only exit by setting stop");
135    info!("Shutting down due to {} and removing {}", stop, socket_path.display());
136
137    let _ = fs::remove_file(&socket_path);
138    // Because we catch signals, daemonize doesn't properly clean up the PID file so we have
139    // to do it ourselves.
140    let _ = fs::remove_file(&pid_file);
141
142    Ok(())
143}
144
145/// Waits for `path` to exist for a maximum period of time using operation `op`.
146/// Returns the result of `op` on success.
147pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
148    path: P,
149    mut pending_wait: Duration,
150    op: fn(P) -> io::Result<T>,
151) -> Result<T> {
152    while pending_wait > Duration::ZERO {
153        match op(path) {
154            Ok(result) => {
155                return Ok(result);
156            }
157            Err(e) if e.kind() == io::ErrorKind::NotFound => {
158                thread::sleep(Duration::from_millis(1));
159                pending_wait -= Duration::from_millis(1);
160            }
161            Err(e) => {
162                return Err(e.to_string());
163            }
164        }
165    }
166    Err("File was not created on time".to_owned())
167}